Make OrderBy(order interface{}, args ...interface{})

Instead of forcing OrderBy to be a string this PR allows
OrderBy to pass in either a builder.Cond or a string-args pair.

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2022-05-29 10:43:31 +01:00
parent 26d291bbc3
commit bccf0686ef
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
8 changed files with 68 additions and 42 deletions

View File

@ -380,7 +380,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error {
seq = 0 seq = 0
} }
} }
var colName = strings.Trim(parts[0], `"`) colName := strings.Trim(parts[0], `"`)
if col := table.GetColumn(colName); col != nil { if col := table.GetColumn(colName); col != nil {
col.Indexes[index.Name] = index.Type col.Indexes[index.Name] = index.Type
} else { } else {
@ -502,9 +502,9 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
} }
} }
var dstTableName = dstTable.Name dstTableName := dstTable.Name
var quoter = dstDialect.Quoter().Quote quoter := dstDialect.Quoter().Quote
var quotedDstTableName = quoter(dstTable.Name) quotedDstTableName := quoter(dstTable.Name)
if dstDialect.URI().Schema != "" { if dstDialect.URI().Schema != "" {
dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name) dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name)) quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name))
@ -1006,10 +1006,10 @@ func (engine *Engine) Asc(colNames ...string) *Session {
} }
// OrderBy will generate "ORDER BY order" // OrderBy will generate "ORDER BY order"
func (engine *Engine) OrderBy(order string) *Session { func (engine *Engine) OrderBy(order interface{}, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.isAutoClose = true session.isAutoClose = true
return session.OrderBy(order) return session.OrderBy(order, args...)
} }
// Prepare enables prepare statement // Prepare enables prepare statement

View File

@ -54,7 +54,7 @@ type Interface interface {
Nullable(...string) *Session Nullable(...string) *Session
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
Omit(columns ...string) *Session Omit(columns ...string) *Session
OrderBy(order string) *Session OrderBy(order interface{}, args ...interface{}) *Session
Ping() error Ping() error
Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error) Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error)
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)

View File

@ -28,7 +28,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -83,7 +83,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
var sumStrs = make([]string, 0, len(columns)) sumStrs := make([]string, 0, len(columns))
for _, colName := range columns { for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.quote(colName) colName = statement.quote(colName)
@ -119,7 +119,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
@ -180,7 +180,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
} }
} }
var selectSQL = statement.SelectStr selectSQL := statement.SelectStr
if len(selectSQL) <= 0 { if len(selectSQL) <= 0 {
if statement.IsDistinct { if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
@ -211,8 +211,8 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
func (statement *Statement) fromBuilder() *strings.Builder { func (statement *Statement) fromBuilder() *strings.Builder {
var builder strings.Builder var builder strings.Builder
var quote = statement.quote quote := statement.quote
var dialect = statement.dialect dialect := statement.dialect
builder.WriteString(" FROM ") builder.WriteString(" FROM ")
@ -290,8 +290,9 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
var orderStr string var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 { if needOrderBy && statement.OrderStr != "" {
orderStr = fmt.Sprintf(" ORDER BY %s", statement.OrderStr) orderStr = fmt.Sprintf(" ORDER BY %s", orderStr)
condArgs = append(condArgs, statement.OrderArgs...)
} }
var groupStr string var groupStr string
@ -321,6 +322,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
if needOrderBy && statement.OrderStr != "" { if needOrderBy && statement.OrderStr != "" {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
condArgs = append(condArgs, statement.OrderArgs...)
} }
if needLimit { if needLimit {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
@ -436,7 +438,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
var columnStr = statement.ColumnStr() columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 { if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {

View File

@ -44,6 +44,7 @@ type Statement struct {
LimitN *int LimitN *int
idParam schemas.PK idParam schemas.PK
OrderStr string OrderStr string
OrderArgs []interface{}
JoinStr string JoinStr string
joinArgs []interface{} joinArgs []interface{}
GroupByStr string GroupByStr string
@ -130,6 +131,7 @@ func (statement *Statement) Reset() {
statement.Start = 0 statement.Start = 0
statement.LimitN = nil statement.LimitN = nil
statement.OrderStr = "" statement.OrderStr = ""
statement.OrderArgs = nil
statement.UseCascade = true statement.UseCascade = true
statement.JoinStr = "" statement.JoinStr = ""
statement.joinArgs = make([]interface{}, 0) statement.joinArgs = make([]interface{}, 0)
@ -455,11 +457,28 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
} }
// OrderBy generate "Order By order" statement // OrderBy generate "Order By order" statement
func (statement *Statement) OrderBy(order string) *Statement { func (statement *Statement) OrderBy(order interface{}, args ...interface{}) *Statement {
var rawOrder string
switch order.(type) {
case (*builder.Builder):
var err error
rawOrder, args, err = order.(*builder.Builder).ToSQL()
if err != nil {
statement.LastError = err
}
case string:
rawOrder = order.(string)
statement.RawParams = args
default:
statement.LastError = ErrUnSupportedSQLType
return statement
}
if len(statement.OrderStr) > 0 { if len(statement.OrderStr) > 0 {
statement.OrderStr += ", " statement.OrderStr += ", "
} }
statement.OrderStr += statement.ReplaceQuote(order) statement.OrderStr += statement.ReplaceQuote(rawOrder)
statement.OrderArgs = append(statement.OrderArgs, args...)
return statement return statement
} }

View File

@ -275,8 +275,8 @@ func (session *Session) Limit(limit int, start ...int) *Session {
// OrderBy provide order by query condition, the input parameter is the content // OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement. // after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session { func (session *Session) OrderBy(order interface{}, args ...interface{}) *Session {
session.statement.OrderBy(order) session.statement.OrderBy(order, args...)
return session return session
} }

View File

@ -129,9 +129,9 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
return 0, ErrNeedDeletedCond return 0, ErrNeedDeletedCond
} }
var tableNameNoQuote = session.statement.TableName() tableNameNoQuote := session.statement.TableName()
var tableName = session.engine.Quote(tableNameNoQuote) tableName := session.engine.Quote(tableNameNoQuote)
var table = session.statement.RefTable table := session.statement.RefTable
var deleteSQL string var deleteSQL string
if len(condSQL) > 0 { if len(condSQL) > 0 {
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL)
@ -142,6 +142,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
var orderSQL string var orderSQL string
if len(session.statement.OrderStr) > 0 { if len(session.statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
condArgs = append(condArgs, session.statement.OrderArgs...)
} }
if pLimitN != nil && *pLimitN > 0 { if pLimitN != nil && *pLimitN > 0 {
limitNValue := *pLimitN limitNValue := *pLimitN
@ -224,7 +225,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
} }
condArgs[0] = val condArgs[0] = val
var colName = deletedColumn.Name colName := deletedColumn.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)

View File

@ -63,6 +63,9 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
if session.statement.OrderStr != "" { if session.statement.OrderStr != "" {
session.statement.OrderStr = "" session.statement.OrderStr = ""
} }
if session.statement.OrderArgs != nil {
session.statement.OrderArgs = nil
}
if session.statement.LimitN != nil { if session.statement.LimitN != nil {
session.statement.LimitN = nil session.statement.LimitN = nil
} }
@ -85,15 +88,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
var isSlice = sliceValue.Kind() == reflect.Slice isSlice := sliceValue.Kind() == reflect.Slice
var isMap = sliceValue.Kind() == reflect.Map isMap := sliceValue.Kind() == reflect.Map
if !isSlice && !isMap { if !isSlice && !isMap {
return errors.New("needs a pointer to a slice or a map") return errors.New("needs a pointer to a slice or a map")
} }
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var tp = tpStruct tp := tpStruct
if session.statement.RefTable == nil { if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
@ -190,7 +193,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
return err return err
} }
var newElemFunc = func(fields []string) reflect.Value { newElemFunc := func(fields []string) reflect.Value {
return utils.New(elemType, len(fields), len(fields)) return utils.New(elemType, len(fields), len(fields))
} }
@ -235,7 +238,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
tb, err := session.engine.tagParser.ParseWithCache(newValue) tb, err := session.engine.tagParser.ParseWithCache(newValue)
if err != nil { if err != nil {
return err return err
@ -249,7 +252,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
} }
for rows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
switch elemType.Kind() { switch elemType.Kind() {
@ -310,7 +313,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")
return ErrCacheFailed return ErrCacheFailed
} }
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -342,7 +345,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ididxes := make(map[string]int) ididxes := make(map[string]int)
var ides []schemas.PK var ides []schemas.PK
var temps = make([]interface{}, len(ids)) temps := make([]interface{}, len(ids))
for idx, id := range ids { for idx, id := range ids {
sid, err := id.ToString() sid, err := id.ToString()
@ -457,7 +460,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
} }
} else if sliceValue.Kind() == reflect.Map { } else if sliceValue.Kind() == reflect.Map {
var key = ids[j] key := ids[j]
keyType := sliceValue.Type().Key() keyType := sliceValue.Type().Key()
keyValue := reflect.New(keyType) keyValue := reflect.New(keyType)
var ikey interface{} var ikey interface{}

View File

@ -60,7 +60,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
ids = make([]schemas.PK, 0) ids = make([]schemas.PK, 0)
for rows.Next() { for rows.Next() {
var res = make([]string, len(table.PrimaryKeys)) res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res) err = rows.ScanSlice(&res)
if err != nil { if err != nil {
return err return err
@ -176,8 +176,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
// -- // --
var err error var err error
var isMap = t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
@ -226,7 +226,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, val) args = append(args, val)
} }
var colName = col.Name colName := col.Name
if isStruct { if isStruct {
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
@ -279,7 +279,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condBeanIsStruct := false condBeanIsStruct := false
if len(condiBean) > 0 { if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok { if c, ok := condiBean[0].(map[string]interface{}); ok {
var eq = make(builder.Eq) eq := make(builder.Eq)
for k, v := range c { for k, v := range c {
eq[session.engine.Quote(k)] = v eq[session.engine.Quote(k)] = v
} }
@ -357,10 +357,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
if st.OrderStr != "" { if st.OrderStr != "" {
condSQL += fmt.Sprintf(" ORDER BY %v", st.OrderStr) condSQL += fmt.Sprintf(" ORDER BY %s", st.OrderStr)
condArgs = append(condArgs, st.OrderArgs...)
} }
var tableName = session.statement.TableName() tableName := session.statement.TableName()
// TODO: Oracle support needed // TODO: Oracle support needed
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
@ -410,7 +411,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
var tableAlias = session.engine.Quote(tableName) tableAlias := session.engine.Quote(tableName)
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType { switch session.engine.dialect.URI().DBType {
@ -535,7 +536,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
} }
args = append(args, val) args = append(args, val)
var colName = col.Name colName := col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)