Add missing index hint parameter

This commit is contained in:
Lunny Xiao 2023-12-19 12:30:36 +08:00
parent b571d91858
commit 1a6eb5a7bd
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 24 additions and 7 deletions

View File

@ -1434,9 +1434,9 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf
return result, nil
}
func (engine *Engine) IndexHint(op, indexerOrColName string) *Session {
func (engine *Engine) IndexHint(op, forType, indexerOrColName string) *Session {
session := engine.NewSession()
session.isAutoClose = true
session.statement.LastError = session.statement.IndexHint(op, indexerOrColName)
session.statement.LastError = session.statement.IndexHint(op, forType, indexerOrColName)
return session
}

View File

@ -19,10 +19,11 @@ func (e ErrInvalidIndexHintOperator) Error() string {
return "invalid index hint operator: " + e.Op
}
func (statement *Statement) IndexHint(op, indexName string) error {
func (statement *Statement) IndexHint(op, forType, indexName string) error {
op = strings.ToUpper(op)
statement.indexHints = append(statement.indexHints, indexHint{
op: op,
forType: forType,
indexName: indexName,
})
return nil
@ -46,7 +47,16 @@ func (statement *Statement) writeIndexHintsMySQL(w *builder.BytesWriter) error {
if hint.op != "USE" && hint.op != "FORCE" && hint.op != "IGNORE" {
return ErrInvalidIndexHintOperator{Op: hint.op}
}
if err := statement.writeStrings(" ", hint.op, " INDEX(", hint.indexName, ")")(w); err != nil {
if err := statement.writeStrings(" ", hint.op, " INDEX")(w); err != nil {
return err
}
if hint.forType != "" {
if err := statement.writeStrings(" FOR ", hint.forType)(w); err != nil {
return err
}
}
if err := statement.writeStrings("(", hint.indexName, ")")(w); err != nil {
return err
}
}

View File

@ -43,6 +43,7 @@ type join struct {
type indexHint struct {
op string
forType string
indexName string
}

View File

@ -312,7 +312,7 @@ func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
return results, lastError
}
func (session *Session) IndexHint(op, indexerOrColName string) *Session {
session.statement.IndexHint(op, indexerOrColName)
func (session *Session) IndexHint(op, forType, indexerOrColName string) *Session {
session.statement.IndexHint(op, forType, indexerOrColName)
return session
}

View File

@ -70,6 +70,12 @@ func TestIndexHint(t *testing.T) {
return
}
_, err := testEngine.Table("userinfo").IndexHint("USE", "UQE_userinfo_username").Get(new(Userinfo))
_, err := testEngine.Table("userinfo").IndexHint("USE", "", "UQE_userinfo_username").Get(new(Userinfo))
assert.NoError(t, err)
_, err = testEngine.Table("userinfo").IndexHint("USE", "ORDER BY", "UQE_userinfo_username").Get(new(Userinfo))
assert.NoError(t, err)
_, err = testEngine.Table("userinfo").IndexHint("USE", "GROUP BY", "UQE_userinfo_username").Get(new(Userinfo))
assert.NoError(t, err)
}