diff --git a/engine.go b/engine.go index 31a6b688..459c2f44 100644 --- a/engine.go +++ b/engine.go @@ -1434,9 +1434,9 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf return result, nil } -func (engine *Engine) IndexHint(op, indexerOrColName string) *Session { +func (engine *Engine) IndexHint(op, forType, indexerOrColName string) *Session { session := engine.NewSession() session.isAutoClose = true - session.statement.LastError = session.statement.IndexHint(op, indexerOrColName) + session.statement.LastError = session.statement.IndexHint(op, forType, indexerOrColName) return session } diff --git a/internal/statements/index.go b/internal/statements/index.go index 8f54d242..5c1420eb 100644 --- a/internal/statements/index.go +++ b/internal/statements/index.go @@ -19,10 +19,11 @@ func (e ErrInvalidIndexHintOperator) Error() string { return "invalid index hint operator: " + e.Op } -func (statement *Statement) IndexHint(op, indexName string) error { +func (statement *Statement) IndexHint(op, forType, indexName string) error { op = strings.ToUpper(op) statement.indexHints = append(statement.indexHints, indexHint{ op: op, + forType: forType, indexName: indexName, }) return nil @@ -46,7 +47,16 @@ func (statement *Statement) writeIndexHintsMySQL(w *builder.BytesWriter) error { if hint.op != "USE" && hint.op != "FORCE" && hint.op != "IGNORE" { return ErrInvalidIndexHintOperator{Op: hint.op} } - if err := statement.writeStrings(" ", hint.op, " INDEX(", hint.indexName, ")")(w); err != nil { + if err := statement.writeStrings(" ", hint.op, " INDEX")(w); err != nil { + return err + } + if hint.forType != "" { + if err := statement.writeStrings(" FOR ", hint.forType)(w); err != nil { + return err + } + } + + if err := statement.writeStrings("(", hint.indexName, ")")(w); err != nil { return err } } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 31073ad1..dd4024b5 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -43,6 +43,7 @@ type join struct { type indexHint struct { op string + forType string indexName string } diff --git a/session_schema.go b/session_schema.go index f7ab2657..4bb0b858 100644 --- a/session_schema.go +++ b/session_schema.go @@ -312,7 +312,7 @@ func (session *Session) Import(r io.Reader) ([]sql.Result, error) { return results, lastError } -func (session *Session) IndexHint(op, indexerOrColName string) *Session { - session.statement.IndexHint(op, indexerOrColName) +func (session *Session) IndexHint(op, forType, indexerOrColName string) *Session { + session.statement.IndexHint(op, forType, indexerOrColName) return session } diff --git a/tests/session_test.go b/tests/session_test.go index aeba1141..3286ab2e 100644 --- a/tests/session_test.go +++ b/tests/session_test.go @@ -70,6 +70,12 @@ func TestIndexHint(t *testing.T) { return } - _, err := testEngine.Table("userinfo").IndexHint("USE", "UQE_userinfo_username").Get(new(Userinfo)) + _, err := testEngine.Table("userinfo").IndexHint("USE", "", "UQE_userinfo_username").Get(new(Userinfo)) + assert.NoError(t, err) + + _, err = testEngine.Table("userinfo").IndexHint("USE", "ORDER BY", "UQE_userinfo_username").Get(new(Userinfo)) + assert.NoError(t, err) + + _, err = testEngine.Table("userinfo").IndexHint("USE", "GROUP BY", "UQE_userinfo_username").Get(new(Userinfo)) assert.NoError(t, err) }