diff --git a/engine.go b/engine.go index b7dcf5a2..81cfc7a9 100644 --- a/engine.go +++ b/engine.go @@ -380,7 +380,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error { seq = 0 } } - var colName = strings.Trim(parts[0], `"`) + colName := strings.Trim(parts[0], `"`) if col := table.GetColumn(colName); col != nil { col.Indexes[index.Name] = index.Type } else { @@ -502,9 +502,9 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w } } - var dstTableName = dstTable.Name - var quoter = dstDialect.Quoter().Quote - var quotedDstTableName = quoter(dstTable.Name) + dstTableName := dstTable.Name + quoter := dstDialect.Quoter().Quote + quotedDstTableName := quoter(dstTable.Name) if dstDialect.URI().Schema != "" { dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name) quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name)) @@ -1006,10 +1006,10 @@ func (engine *Engine) Asc(colNames ...string) *Session { } // OrderBy will generate "ORDER BY order" -func (engine *Engine) OrderBy(order string) *Session { +func (engine *Engine) OrderBy(order interface{}, args ...interface{}) *Session { session := engine.NewSession() session.isAutoClose = true - return session.OrderBy(order) + return session.OrderBy(order, args...) } // Prepare enables prepare statement diff --git a/interface.go b/interface.go index b9e88505..55ffebe4 100644 --- a/interface.go +++ b/interface.go @@ -54,7 +54,7 @@ type Interface interface { Nullable(...string) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Omit(columns ...string) *Session - OrderBy(order string) *Session + OrderBy(order interface{}, args ...interface{}) *Session Ping() error Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) diff --git a/internal/statements/query.go b/internal/statements/query.go index 8b383866..0e1e8f71 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -28,7 +28,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, ErrTableNotFound } - var columnStr = statement.ColumnStr() + columnStr := statement.ColumnStr() if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { @@ -83,7 +83,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri return "", nil, err } - var sumStrs = make([]string, 0, len(columns)) + sumStrs := make([]string, 0, len(columns)) for _, colName := range columns { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { colName = statement.quote(colName) @@ -119,7 +119,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } } - var columnStr = statement.ColumnStr() + columnStr := statement.ColumnStr() if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { @@ -180,7 +180,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa } } - var selectSQL = statement.SelectStr + selectSQL := statement.SelectStr if len(selectSQL) <= 0 { if statement.IsDistinct { selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) @@ -211,8 +211,8 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa func (statement *Statement) fromBuilder() *strings.Builder { var builder strings.Builder - var quote = statement.quote - var dialect = statement.dialect + quote := statement.quote + dialect := statement.dialect builder.WriteString(" FROM ") @@ -290,8 +290,9 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } var orderStr string - if needOrderBy && len(statement.OrderStr) > 0 { - orderStr = fmt.Sprintf(" ORDER BY %s", statement.OrderStr) + if needOrderBy && statement.OrderStr != "" { + orderStr = fmt.Sprintf(" ORDER BY %s", orderStr) + condArgs = append(condArgs, statement.OrderArgs...) } var groupStr string @@ -321,6 +322,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } if needOrderBy && statement.OrderStr != "" { fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) + condArgs = append(condArgs, statement.OrderArgs...) } if needLimit { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { @@ -436,7 +438,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa return "", nil, ErrTableNotFound } - var columnStr = statement.ColumnStr() + columnStr := statement.ColumnStr() if len(statement.SelectStr) > 0 { columnStr = statement.SelectStr } else { diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 3069561e..cad97a3c 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -44,6 +44,7 @@ type Statement struct { LimitN *int idParam schemas.PK OrderStr string + OrderArgs []interface{} JoinStr string joinArgs []interface{} GroupByStr string @@ -130,6 +131,7 @@ func (statement *Statement) Reset() { statement.Start = 0 statement.LimitN = nil statement.OrderStr = "" + statement.OrderArgs = nil statement.UseCascade = true statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) @@ -455,11 +457,28 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement { } // OrderBy generate "Order By order" statement -func (statement *Statement) OrderBy(order string) *Statement { +func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement { + var rawOrder string + switch order.(type) { + case (*builder.Builder): + var err error + rawOrder, args, err = order.(*builder.Builder).ToSQL() + if err != nil { + statement.LastError = err + } + case string: + rawOrder = order.(string) + statement.RawParams = args + default: + statement.LastError = ErrUnSupportedSQLType + return statement + } + if len(statement.OrderStr) > 0 { statement.OrderStr += ", " } - statement.OrderStr += statement.ReplaceQuote(order) + statement.OrderStr += statement.ReplaceQuote(rawOrder) + statement.OrderArgs = append(statement.OrderArgs, args...) return statement } diff --git a/session.go b/session.go index 64b98bfe..388678cd 100644 --- a/session.go +++ b/session.go @@ -275,8 +275,8 @@ func (session *Session) Limit(limit int, start ...int) *Session { // OrderBy provide order by query condition, the input parameter is the content // after order by on a sql statement. -func (session *Session) OrderBy(order string) *Session { - session.statement.OrderBy(order) +func (session *Session) OrderBy(order interface{}, args ...interface{}) *Session { + session.statement.OrderBy(order, args...) return session } diff --git a/session_delete.go b/session_delete.go index a0f420b1..4cfec66c 100644 --- a/session_delete.go +++ b/session_delete.go @@ -129,9 +129,9 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) { return 0, ErrNeedDeletedCond } - var tableNameNoQuote = session.statement.TableName() - var tableName = session.engine.Quote(tableNameNoQuote) - var table = session.statement.RefTable + tableNameNoQuote := session.statement.TableName() + tableName := session.engine.Quote(tableNameNoQuote) + table := session.statement.RefTable var deleteSQL string if len(condSQL) > 0 { deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) @@ -142,6 +142,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) { var orderSQL string if len(session.statement.OrderStr) > 0 { orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) + condArgs = append(condArgs, session.statement.OrderArgs...) } if pLimitN != nil && *pLimitN > 0 { limitNValue := *pLimitN @@ -224,7 +225,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) { } condArgs[0] = val - var colName = deletedColumn.Name + colName := deletedColumn.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t) diff --git a/session_find.go b/session_find.go index caf79ee3..7685c1d1 100644 --- a/session_find.go +++ b/session_find.go @@ -63,6 +63,9 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte if session.statement.OrderStr != "" { session.statement.OrderStr = "" } + if session.statement.OrderArgs != nil { + session.statement.OrderArgs = nil + } if session.statement.LimitN != nil { session.statement.LimitN = nil } @@ -85,15 +88,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - var isSlice = sliceValue.Kind() == reflect.Slice - var isMap = sliceValue.Kind() == reflect.Map + isSlice := sliceValue.Kind() == reflect.Slice + isMap := sliceValue.Kind() == reflect.Map if !isSlice && !isMap { return errors.New("needs a pointer to a slice or a map") } sliceElementType := sliceValue.Type().Elem() - var tp = tpStruct + tp := tpStruct if session.statement.RefTable == nil { if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { @@ -190,7 +193,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } - var newElemFunc = func(fields []string) reflect.Value { + newElemFunc := func(fields []string) reflect.Value { return utils.New(elemType, len(fields), len(fields)) } @@ -235,7 +238,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } if elemType.Kind() == reflect.Struct { - var newValue = newElemFunc(fields) + newValue := newElemFunc(fields) tb, err := session.engine.tagParser.ParseWithCache(newValue) if err != nil { return err @@ -249,7 +252,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { - var newValue = newElemFunc(fields) + newValue := newElemFunc(fields) bean := newValue.Interface() switch elemType.Kind() { @@ -310,7 +313,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") return ErrCacheFailed } - var res = make([]string, len(table.PrimaryKeys)) + res := make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err @@ -342,7 +345,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ididxes := make(map[string]int) var ides []schemas.PK - var temps = make([]interface{}, len(ids)) + temps := make([]interface{}, len(ids)) for idx, id := range ids { sid, err := id.ToString() @@ -457,7 +460,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) } } else if sliceValue.Kind() == reflect.Map { - var key = ids[j] + key := ids[j] keyType := sliceValue.Type().Key() keyValue := reflect.New(keyType) var ikey interface{} diff --git a/session_update.go b/session_update.go index fefbee90..e19ccaf0 100644 --- a/session_update.go +++ b/session_update.go @@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { - var res = make([]string, len(table.PrimaryKeys)) + res := make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { return err @@ -176,8 +176,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // -- var err error - var isMap = t.Kind() == reflect.Map - var isStruct = t.Kind() == reflect.Struct + isMap := t.Kind() == reflect.Map + isStruct := t.Kind() == reflect.Struct if isStruct { if err := session.statement.SetRefBean(bean); err != nil { return 0, err @@ -226,7 +226,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, val) } - var colName = col.Name + colName := col.Name if isStruct { session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) @@ -279,7 +279,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { - var eq = make(builder.Eq) + eq := make(builder.Eq) for k, v := range c { eq[session.engine.Quote(k)] = v } @@ -357,10 +357,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if st.OrderStr != "" { - condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr) + condSQL += fmt.Sprintf(" ORDER BY %s", st.OrderStr) + condArgs = append(condArgs, st.OrderArgs...) } - var tableName = session.statement.TableName() + tableName := session.statement.TableName() // TODO: Oracle support needed var top string if st.LimitN != nil { @@ -410,7 +411,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - var tableAlias = session.engine.Quote(tableName) + tableAlias := session.engine.Quote(tableName) var fromSQL string if session.statement.TableAlias != "" { switch session.engine.dialect.URI().DBType { @@ -535,7 +536,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } args = append(args, val) - var colName = col.Name + colName := col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { col := table.GetColumn(colName) setColumnTime(bean, col, t)