// Copyright 2015 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package statements import ( "database/sql/driver" "errors" "fmt" "math/big" "reflect" "strings" "time" "xorm.io/builder" "xorm.io/xorm/v2/contexts" "xorm.io/xorm/v2/dialects" "xorm.io/xorm/v2/internal/convert" "xorm.io/xorm/v2/internal/json" "xorm.io/xorm/v2/internal/utils" "xorm.io/xorm/v2/schemas" "xorm.io/xorm/v2/tags" ) var ( // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") // ErrUnSupportedSQLType parameter of SQL is not supported ErrUnSupportedSQLType = errors.New("Unsupported sql type") // ErrUnSupportedType unsupported error ErrUnSupportedType = errors.New("Unsupported type error") // ErrTableNotFound table not found error ErrTableNotFound = errors.New("Table not found") ) type join struct { op string table interface{} condition interface{} args []interface{} } // Statement save all the sql info for executing SQL type Statement struct { RefTable *schemas.Table dialect dialects.Dialect defaultTimeZone *time.Location tagParser *tags.Parser Start int LimitN *int idParam schemas.PK orderBy []orderBy joins []join GroupByStr string HavingStr string SelectStr string useAllCols bool AltTableName string tableName string RawSQL string RawParams []interface{} UseCascade bool UseAutoJoin bool StoreEngine string Charset string UseAutoTime bool NoAutoCondition bool IsDistinct bool IsForUpdate bool TableAlias string allUseBool bool CheckVersion bool unscoped bool ColumnMap columnMap OmitColumnMap columnMap MustColumnMap map[string]bool NullableMap map[string]bool IncrColumns exprParams DecrColumns exprParams ExprColumns exprParams cond builder.Cond BufferSize int Context contexts.ContextCache LastError error } // NewStatement creates a new statement func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement { statement := &Statement{ dialect: dialect, tagParser: tagParser, defaultTimeZone: defaultTimeZone, } statement.Reset() return statement } // SetTableName set table name func (statement *Statement) SetTableName(tableName string) { statement.tableName = tableName } // GenRawSQL generates correct raw sql func (statement *Statement) GenRawSQL() string { return statement.ReplaceQuote(statement.RawSQL) } // ReplaceQuote replace sql key words with quote func (statement *Statement) ReplaceQuote(sql string) string { if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || statement.dialect.URI().DBType == schemas.SQLITE { return sql } return statement.dialect.Quoter().Replace(sql) } // SetContextCache sets context cache func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { statement.Context = ctxCache } // Reset reset all the statement's fields func (statement *Statement) Reset() { statement.RefTable = nil statement.Start = 0 statement.LimitN = nil statement.ResetOrderBy() statement.UseCascade = true statement.joins = nil statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnMap = columnMap{} statement.OmitColumnMap = columnMap{} statement.AltTableName = "" statement.tableName = "" statement.idParam = nil statement.RawSQL = "" statement.RawParams = make([]interface{}, 0) statement.UseAutoTime = true statement.NoAutoCondition = false statement.IsDistinct = false statement.IsForUpdate = false statement.TableAlias = "" statement.SelectStr = "" statement.allUseBool = false statement.useAllCols = false statement.MustColumnMap = make(map[string]bool) statement.NullableMap = make(map[string]bool) statement.CheckVersion = true statement.unscoped = false statement.IncrColumns = exprParams{} statement.DecrColumns = exprParams{} statement.ExprColumns = exprParams{} statement.cond = builder.NewCond() statement.BufferSize = 0 statement.Context = nil statement.LastError = nil } // SQL adds raw sql statement func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { switch t := query.(type) { case (*builder.Builder): var err error statement.RawSQL, statement.RawParams, err = t.ToSQL() if err != nil { statement.LastError = err } case string: statement.RawSQL = t statement.RawParams = args default: statement.LastError = ErrUnSupportedSQLType } return statement } func (statement *Statement) quote(s string) string { return statement.dialect.Quoter().Quote(s) } // SetRefValue set ref value func (statement *Statement) SetRefValue(v reflect.Value) error { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) if err != nil { return err } statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true) return nil } func rValue(bean interface{}) reflect.Value { return reflect.Indirect(reflect.ValueOf(bean)) } // SetRefBean set ref bean func (statement *Statement) SetRefBean(bean interface{}) error { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) if err != nil { return err } statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true) return nil } func (statement *Statement) NeedTableName() bool { return len(statement.joins) > 0 } // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { statement.IncrColumns.Add(column, arg[0]) } else { statement.IncrColumns.Add(column, 1) } return statement } // Decr Generate "Update ... Set column = column - arg" statement func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { statement.DecrColumns.Add(column, arg[0]) } else { statement.DecrColumns.Add(column, 1) } return statement } // SetExpr Generate "Update ... Set column = {expression}" statement func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { if e, ok := expression.(string); ok { statement.ExprColumns.Add(column, statement.dialect.Quoter().Replace(e)) } else { statement.ExprColumns.Add(column, expression) } return statement } // ForUpdate generates "SELECT ... FOR UPDATE" statement func (statement *Statement) ForUpdate() *Statement { statement.IsForUpdate = true return statement } // Nullable Update use only: update columns to null when value is nullable and zero-value func (statement *Statement) Nullable(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { statement.NullableMap[strings.ToLower(nc)] = true } } // Top generate LIMIT limit statement func (statement *Statement) Top(limit int) *Statement { statement.Limit(limit) return statement } // Limit generate LIMIT start, limit statement func (statement *Statement) Limit(limit int, start ...int) *Statement { statement.LimitN = &limit if len(start) > 0 { statement.Start = start[0] } return statement } // SetTable tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) SetTable(tableNameOrBean interface{}) error { v := rValue(tableNameOrBean) t := v.Type() if t.Kind() == reflect.Struct { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(v) if err != nil { return err } } statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true) return nil } // GroupBy generate "Group By keys" statement func (statement *Statement) GroupBy(keys string) *Statement { statement.GroupByStr = statement.ReplaceQuote(keys) return statement } func (statement *Statement) writeGroupBy(w *builder.BytesWriter) error { if statement.GroupByStr == "" { return nil } _, err := fmt.Fprint(w, " GROUP BY ", statement.GroupByStr) return err } // Having generate "Having conditions" statement func (statement *Statement) Having(conditions string) *Statement { statement.HavingStr = conditions return statement } func (statement *Statement) writeHaving(w *builder.BytesWriter) error { if statement.HavingStr == "" { return nil } _, err := fmt.Fprint(w, " HAVING ", statement.ReplaceQuote(statement.HavingStr)) return err } // SetUnscoped always disable struct tag "deleted" func (statement *Statement) SetUnscoped() *Statement { statement.unscoped = true return statement } // GetUnscoped return true if it's unscoped func (statement *Statement) GetUnscoped() bool { return statement.unscoped } // GenIndexSQL generated create index SQL func (statement *Statement) GenIndexSQL() []string { var sqls []string tbName := statement.TableName() for _, index := range statement.RefTable.Indexes { if index.Type == schemas.IndexType { sql := statement.dialect.CreateIndexSQL(tbName, index) sqls = append(sqls, sql) } } return sqls } // GenUniqueSQL generates unique SQL func (statement *Statement) GenUniqueSQL() []string { var sqls []string tbName := statement.TableName() for _, index := range statement.RefTable.Indexes { if index.Type == schemas.UniqueType { sql := statement.dialect.CreateIndexSQL(tbName, index) sqls = append(sqls, sql) } } return sqls } // GenDelIndexSQL generate delete index SQL func (statement *Statement) GenDelIndexSQL() []string { var sqls []string tbName := statement.TableName() idx := strings.Index(tbName, ".") if idx > -1 { tbName = tbName[idx+1:] } for _, index := range statement.RefTable.Indexes { sqls = append(sqls, statement.dialect.DropIndexSQL(tbName, index)) } return sqls } func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect.Type, col *schemas.Column, allUseBool, requiredField bool) (interface{}, bool, error) { switch fieldType.Kind() { case reflect.Ptr: if fieldValue.IsNil() { return nil, true, nil } return statement.asDBCond(fieldValue.Elem(), fieldType.Elem(), col, allUseBool, requiredField) case reflect.Bool: if allUseBool || requiredField { return fieldValue.Interface(), true, nil } // if a bool in a struct, it will not be as a condition because it default is false, // please use Where() instead return nil, false, nil case reflect.String: if !requiredField && fieldValue.String() == "" { return nil, false, nil } // for MyString, should convert to string or panic if fieldType.String() != reflect.String.String() { return fieldValue.String(), true, nil } return fieldValue.Interface(), true, nil case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: if !requiredField && fieldValue.Int() == 0 { return nil, false, nil } return fieldValue.Interface(), true, nil case reflect.Float32, reflect.Float64: if !requiredField && fieldValue.Float() == 0.0 { return nil, false, nil } return fieldValue.Interface(), true, nil case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: if !requiredField && fieldValue.Uint() == 0 { return nil, false, nil } return fieldValue.Interface(), true, nil case reflect.Struct: if fieldType.ConvertibleTo(schemas.TimeType) { t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { return nil, false, nil } res, err := dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) if err != nil { return nil, false, err } return res, true, nil } else if fieldType.ConvertibleTo(schemas.BigFloatType) { t := fieldValue.Convert(schemas.BigFloatType).Interface().(big.Float) v := t.String() if v == "0" { return nil, false, nil } return t.String(), true, nil } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { return nil, false, nil } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { val, _ := valNul.Value() if val == nil && !requiredField { return nil, false, nil } return val, true, nil } else { if col.IsJSON { if col.SQLType.IsText() { bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { return nil, false, err } return string(bytes), true, nil } else if col.SQLType.IsBlob() { var bytes []byte var err error bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { return nil, false, err } return bytes, true, nil } } else { table, err := statement.tagParser.ParseWithCache(fieldValue) if err != nil { return fieldValue.Interface(), true, nil } if len(table.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) // fix non-int pk issues // if pkField.Int() != 0 { if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { return pkField.Interface(), true, nil } return nil, false, nil } return nil, false, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) } } case reflect.Array: return nil, false, nil case reflect.Slice, reflect.Map: if fieldValue == reflect.Zero(fieldType) { return nil, false, nil } if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { return nil, false, nil } if col.SQLType.IsText() { bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { return nil, false, err } return string(bytes), true, nil } else if col.SQLType.IsBlob() { var bytes []byte var err error if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && fieldType.Elem().Kind() == reflect.Uint8 { if fieldValue.Len() > 0 { return fieldValue.Bytes(), true, nil } return nil, false, nil } bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { return nil, false, err } return bytes, true, nil } return nil, false, nil } return fieldValue.Interface(), true, nil } func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool, ) (builder.Cond, error) { var conds []builder.Cond for _, col := range table.Columns() { if !includeVersion && col.IsVersion { continue } if !includeUpdated && col.IsUpdated { continue } if !includeAutoIncr && col.IsAutoIncrement { continue } if col.IsJSON { continue } var colName string if addedTableName { nm := tableName if len(aliasName) > 0 { nm = aliasName } colName = statement.quote(nm) + "." + statement.quote(col.Name) } else { colName = statement.quote(col.Name) } fieldValuePtr, err := col.ValueOf(bean) if err != nil { continue } else if fieldValuePtr == nil { continue } if col.IsDeleted && !unscoped { // tag "deleted" is enabled conds = append(conds, statement.CondDeleted(col)) } fieldValue := *fieldValuePtr if fieldValue.Interface() == nil { continue } if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { if utils.IsValueZero(fieldValue) { continue } return nil, fmt.Errorf("column %s is a TEXT type with data %#v which cannot be as compare condition", col.Name, fieldValue.Interface()) } requiredField := useAllCols if b, ok := getFlagForColumn(mustColumnMap, col); ok { if b { requiredField = true } else { continue } } fieldType := reflect.TypeOf(fieldValue.Interface()) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { conds = append(conds, builder.Eq{colName: nil}) } continue } else if !fieldValue.IsValid() { continue } else { // dereference ptr type to instance type fieldValue = fieldValue.Elem() fieldType = reflect.TypeOf(fieldValue.Interface()) requiredField = true } } val, ok, err := statement.asDBCond(fieldValue, fieldType, col, allUseBool, requiredField) if err != nil { return nil, err } if !ok { continue } conds = append(conds, builder.Eq{colName: val}) } return builder.And(conds...), nil } // BuildConds builds condition func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) } // MergeConds merge conditions from bean and id func (statement *Statement) MergeConds(bean interface{}) error { if !statement.NoAutoCondition && statement.RefTable != nil { addedTableName := (len(statement.joins) > 0) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { return err } statement.cond = statement.cond.And(autoCond) } return statement.ProcessIDParam() } func (statement *Statement) quoteColumnStr(columnStr string) string { columns := strings.Split(columnStr, ",") return statement.dialect.Quoter().Join(columns, ",") } // ConvertSQLOrArgs converts sql or args func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { sql, args, err := statement.convertSQLOrArgs(sqlOrArgs...) if err != nil { return "", nil, err } return statement.ReplaceQuote(sql), args, nil } func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { switch sqlOrArgs[0].(type) { case string: if len(sqlOrArgs) > 1 { newArgs := make([]interface{}, 0, len(sqlOrArgs)-1) for _, arg := range sqlOrArgs[1:] { if v, ok := arg.(time.Time); ok { newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) } else if v, ok := arg.(*time.Time); ok && v != nil { newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) } else if v, ok := arg.(convert.ConversionTo); ok { r, err := v.ToDB() if err != nil { return "", nil, err } if r != nil { // for nvarchar column on mssql, bytes have to be converted as ucs-2 external of driver // for binary column, a string will be converted as bytes directly. So we have to // convert bytes as string if statement.dialect.URI().DBType == schemas.MSSQL { newArgs = append(newArgs, string(r)) } else { newArgs = append(newArgs, r) } } else { newArgs = append(newArgs, nil) } } else { newArgs = append(newArgs, arg) } } return sqlOrArgs[0].(string), newArgs, nil } return sqlOrArgs[0].(string), sqlOrArgs[1:], nil case *builder.Builder: return sqlOrArgs[0].(*builder.Builder).ToSQL() case builder.Builder: bd := sqlOrArgs[0].(builder.Builder) return bd.ToSQL() } return "", nil, ErrUnSupportedType } func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string { colnames := make([]string, len(cols)) for i, col := range cols { if includeTableName { colnames[i] = statement.quote(statement.TableName()) + "." + statement.quote(col.Name) } else { colnames[i] = statement.quote(col.Name) } } return strings.Join(colnames, ", ") } // CondDeleted returns the conditions whether a record is soft deleted. func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { colName := statement.quote(col.Name) if len(statement.joins) > 0 { var prefix string if statement.TableAlias != "" { prefix = statement.TableAlias } else { prefix = statement.TableName() } colName = statement.quote(prefix) + "." + statement.quote(col.Name) } cond := builder.NewCond() if col.SQLType.IsNumeric() { cond = builder.Eq{colName: 0} } else { // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. if statement.dialect.URI().DBType != schemas.MSSQL { cond = builder.Eq{colName: utils.ZeroTime1} } } if col.Nullable { cond = cond.Or(builder.IsNull{colName}) } return cond }