diff --git a/context_cache.go b/contexts/context_cache.go similarity index 97% rename from context_cache.go rename to contexts/context_cache.go index 1bc22884..0d0f0f02 100644 --- a/context_cache.go +++ b/contexts/context_cache.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package contexts // ContextCache is the interface that operates the cache data. type ContextCache interface { diff --git a/dialects/dialect.go b/dialects/dialect.go index 26d6521a..e9e512ee 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -41,6 +41,7 @@ type Dialect interface { DBType() DBType SQLType(*schemas.Column) string FormatBytes(b []byte) string + DefaultSchema() string DriverName() string DataSourceName() string @@ -103,6 +104,10 @@ func (b *Base) SetLogger(logger log.Logger) { b.logger = logger } +func (b *Base) DefaultSchema() string { + return "" +} + func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error { b.db, b.dialect, b.uri = db, dialect, uri b.driverName, b.dataSourceName = drivername, dataSourceName diff --git a/dialects/postgres.go b/dialects/postgres.go index d6847b02..94514e95 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -788,6 +788,10 @@ func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string return nil } +func (db *postgres) DefaultSchema() string { + return PostgresPublicSchema +} + func (db *postgres) SQLType(c *schemas.Column) string { var res string switch t := c.SQLType.Name; t { diff --git a/dialects/table_name.go b/dialects/table_name.go new file mode 100644 index 00000000..a989b386 --- /dev/null +++ b/dialects/table_name.go @@ -0,0 +1,90 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "fmt" + "reflect" + "strings" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/names" +) + +// TableNameWithSchema will add schema prefix on table name if possible +func TableNameWithSchema(dialect Dialect, tableName string) string { + // Add schema name as prefix of table name. + // Only for postgres database. + if dialect.URI().Schema != "" && + dialect.URI().Schema != dialect.DefaultSchema() && + strings.Index(tableName, ".") == -1 { + return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName) + } + return tableName +} + +// TableNameNoSchema returns table name with given tableName +func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string { + quote := dialect.Quoter().Quote + switch tableName.(type) { + case []string: + t := tableName.([]string) + if len(t) > 1 { + return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1])) + } else if len(t) == 1 { + return quote(t[0]) + } + case []interface{}: + t := tableName.([]interface{}) + l := len(t) + var table string + if l > 0 { + f := t[0] + switch f.(type) { + case string: + table = f.(string) + case names.TableName: + table = f.(names.TableName).TableName() + default: + v := utils.ReflectValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + table = names.GetTableName(mapper, v) + } else { + table = quote(fmt.Sprintf("%v", f)) + } + } + } + if l > 1 { + return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1]))) + } else if l == 1 { + return quote(table) + } + case names.TableName: + return tableName.(names.TableName).TableName() + case string: + return tableName.(string) + case reflect.Value: + v := tableName.(reflect.Value) + return names.GetTableName(mapper, v) + default: + v := utils.ReflectValue(tableName) + t := v.Type() + if t.Kind() == reflect.Struct { + return names.GetTableName(mapper, v) + } + return quote(fmt.Sprintf("%v", tableName)) + } + return "" +} + +// FullTableName returns table name with quote and schema according parameter +func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string { + tbName := TableNameNoSchema(dialect, mapper, bean) + if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) { + tbName = TableNameWithSchema(dialect, tbName) + } + return tbName +} diff --git a/engine_table_test.go b/dialects/table_name_test.go similarity index 60% rename from engine_table_test.go rename to dialects/table_name_test.go index 8f2300aa..66edc2b4 100644 --- a/engine_table_test.go +++ b/dialects/table_name_test.go @@ -2,11 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package dialects import ( "testing" + "xorm.io/xorm/names" + "github.com/stretchr/testify/assert" ) @@ -20,9 +22,9 @@ func (mcc *MCC) TableName() string { return "mcc" } -func TestTableName1(t *testing.T) { - assert.NoError(t, prepareEngine()) +func TestFullTableName(t *testing.T) { + dialect := QueryDialect("mysql") - assert.EqualValues(t, "mcc", testEngine.TableName(new(MCC))) - assert.EqualValues(t, "mcc", testEngine.TableName("mcc")) + assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{})) + assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc")) } diff --git a/dialects/time.go b/dialects/time.go new file mode 100644 index 00000000..022dc960 --- /dev/null +++ b/dialects/time.go @@ -0,0 +1,49 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dialects + +import ( + "time" + + "xorm.io/xorm/schemas" +) + +// FormatTime format time as column type +func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) { + switch sqlTypeName { + case schemas.Time: + s := t.Format("2006-01-02 15:04:05") // time.RFC3339 + v = s[11:19] + case schemas.Date: + v = t.Format("2006-01-02") + case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. + v = t.Format("2006-01-02 15:04:05") + case schemas.TimeStampz: + if dialect.DBType() == schemas.MSSQL { + v = t.Format("2006-01-02T15:04:05.9999999Z07:00") + } else { + v = t.Format(time.RFC3339Nano) + } + case schemas.BigInt, schemas.Int: + v = t.Unix() + default: + v = t + } + return +} + +func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) { + if t.IsZero() { + if col.Nullable { + return nil + } + return "" + } + + if col.TimeZone != nil { + return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone)) + } + return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone)) +} diff --git a/engine.go b/engine.go index 1bf42d15..cf0126e9 100644 --- a/engine.go +++ b/engine.go @@ -18,7 +18,6 @@ import ( "strings" "time" - "xorm.io/builder" "xorm.io/xorm/caches" "xorm.io/xorm/core" "xorm.io/xorm/dialects" @@ -65,25 +64,6 @@ func (engine *Engine) BufferSize(size int) *Session { return session.BufferSize(size) } -// CondDeleted returns the conditions whether a record is soft deleted. -func (engine *Engine) CondDeleted(col *schemas.Column) builder.Cond { - var cond = builder.NewCond() - if col.SQLType.IsNumeric() { - cond = builder.Eq{col.Name: 0} - } else { - // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. - if engine.dialect.DBType() != schemas.MSSQL { - cond = builder.Eq{col.Name: utils.ZeroTime1} - } - } - - if col.Nullable { - cond = cond.Or(builder.IsNull{col.Name}) - } - - return cond -} - // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) @@ -237,7 +217,7 @@ func (engine *Engine) NoCascade() *Session { // MapCacher Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { - engine.SetCacher(engine.TableName(bean, true), cacher) + engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher) return nil } @@ -759,13 +739,13 @@ func (t *Table) IsValid() bool { } // TableInfo get table info according to bean's content -func (engine *Engine) TableInfo(bean interface{}) *Table { - v := rValue(bean) +func (engine *Engine) TableInfo(bean interface{}) (*Table, error) { + v := utils.ReflectValue(bean) tb, err := engine.tagParser.MapType(v) if err != nil { - engine.logger.Error(err) + return nil, err } - return &Table{tb, engine.TableName(bean)} + return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil } // IsTableEmpty if a table has any reocrd @@ -787,6 +767,11 @@ func (engine *Engine) IDOf(bean interface{}) schemas.PK { return engine.IDOfV(reflect.ValueOf(bean)) } +// TableName returns table name with schema prefix if has +func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { + return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...) +} + // IDOfV get id from one value of struct func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK { pk, err := engine.idOfV(rv) @@ -873,7 +858,7 @@ func (engine *Engine) CreateUniques(bean interface{}) error { // ClearCacheBean if enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { - tableName := engine.TableName(bean) + tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) cacher := engine.GetCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) @@ -885,7 +870,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { // ClearCache if enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { for _, bean := range beans { - tableName := engine.TableName(bean) + tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) cacher := engine.GetCacher(tableName) if cacher != nil { cacher.ClearIds(tableName) @@ -908,8 +893,8 @@ func (engine *Engine) Sync(beans ...interface{}) error { defer session.Close() for _, bean := range beans { - v := rValue(bean) - tableNameNoSchema := engine.TableName(bean) + v := utils.ReflectValue(bean) + tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) table, err := engine.tagParser.MapType(v) if err != nil { return err @@ -946,7 +931,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } err = session.addColumn(col.Name) @@ -957,7 +942,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } if index.Type == schemas.UniqueType { @@ -966,7 +951,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } @@ -981,7 +966,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } @@ -1250,45 +1235,11 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { if !col.DisableTimeZone && col.TimeZone != nil { tz = col.TimeZone } - return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) + return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) } func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) { - if t.IsZero() { - if col.Nullable { - return nil - } - return "" - } - - if col.TimeZone != nil { - return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone)) - } - return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ)) -} - -// formatTime format time as column type -func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { - switch sqlTypeName { - case schemas.Time: - s := t.Format("2006-01-02 15:04:05") // time.RFC3339 - v = s[11:19] - case schemas.Date: - v = t.Format("2006-01-02") - case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar. - v = t.Format("2006-01-02 15:04:05") - case schemas.TimeStampz: - if engine.dialect.DBType() == schemas.MSSQL { - v = t.Format("2006-01-02T15:04:05.9999999Z07:00") - } else { - v = t.Format(time.RFC3339Nano) - } - case schemas.BigInt, schemas.Int: - v = t.Unix() - default: - v = t - } - return + return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t) } // GetColumnMapper returns the column name mapper @@ -1332,3 +1283,7 @@ func (engine *Engine) Unscoped() *Session { session.isAutoClose = true return session.Unscoped() } + +func (engine *Engine) tbNameWithSchema(v string) string { + return dialects.TableNameWithSchema(engine.dialect, v) +} diff --git a/engine_cond.go b/engine_cond.go deleted file mode 100644 index e757df11..00000000 --- a/engine_cond.go +++ /dev/null @@ -1,234 +0,0 @@ -// Copyright 2017 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "database/sql/driver" - "fmt" - "reflect" - "strings" - "time" - - "xorm.io/builder" - "xorm.io/xorm/convert" - "xorm.io/xorm/internal/utils" - "xorm.io/xorm/schemas" -) - -func (engine *Engine) buildConds(table *schemas.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { - var conds []builder.Cond - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - - if engine.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { - continue - } - if col.SQLType.IsJson() { - continue - } - - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - if !strings.Contains(err.Error(), "is not valid") { - engine.logger.Warn(err) - } - continue - } - - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - conds = append(conds, engine.CondDeleted(col)) - } - - fieldValue := *fieldValuePtr - if fieldValue.Interface() == nil { - continue - } - - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - conds = append(conds, builder.Eq{colName: nil}) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil { - continue - } - } else { - if col.SQLType.IsJson() { - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - table, err := engine.tagParser.MapType(fieldValue) - if err != nil { - val = fieldValue.Interface() - } else { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } - } - } - case reflect.Array: - continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - conds = append(conds, builder.Eq{colName: val}) - } - - return builder.And(conds...), nil -} diff --git a/engine_table.go b/engine_table.go deleted file mode 100644 index 0954b2d3..00000000 --- a/engine_table.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2018 The Xorm Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package xorm - -import ( - "fmt" - "reflect" - "strings" - - "xorm.io/xorm/dialects" - "xorm.io/xorm/names" - "xorm.io/xorm/schemas" -) - -// tbNameWithSchema will automatically add schema prefix on table name -func (engine *Engine) tbNameWithSchema(v string) string { - // Add schema name as prefix of table name. - // Only for postgres database. - if engine.dialect.DBType() == schemas.POSTGRES && - engine.dialect.URI().Schema != "" && - engine.dialect.URI().Schema != dialects.PostgresPublicSchema && - strings.Index(v, ".") == -1 { - return engine.dialect.URI().Schema + "." + v - } - return v -} - -func isSubQuery(tbName string) bool { - const selStr = "select" - if len(tbName) <= len(selStr)+1 { - return false - } - - return strings.EqualFold(tbName[:len(selStr)], selStr) || strings.EqualFold(tbName[:len(selStr)+1], "("+selStr) -} - -// TableName returns table name with schema prefix if has -func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { - tbName := engine.tbNameNoSchema(bean) - if len(includeSchema) > 0 && includeSchema[0] && !isSubQuery(tbName) { - tbName = engine.tbNameWithSchema(tbName) - } - return tbName -} - -// tbName get some table's table name -func (session *Session) tbNameNoSchema(table *schemas.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName - } - - return table.Name -} - -func (engine *Engine) tbNameNoSchema(tablename interface{}) string { - switch tablename.(type) { - case []string: - t := tablename.([]string) - if len(t) > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) - } else if len(t) == 1 { - return engine.Quote(t[0]) - } - case []interface{}: - t := tablename.([]interface{}) - l := len(t) - var table string - if l > 0 { - f := t[0] - switch f.(type) { - case string: - table = f.(string) - case names.TableName: - table = f.(names.TableName).TableName() - default: - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.Struct { - table = names.GetTableName(engine.GetTableMapper(), v) - } else { - table = engine.Quote(fmt.Sprintf("%v", f)) - } - } - } - if l > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(table), - engine.Quote(fmt.Sprintf("%v", t[1]))) - } else if l == 1 { - return engine.Quote(table) - } - case names.TableName: - return tablename.(names.TableName).TableName() - case string: - return tablename.(string) - case reflect.Value: - v := tablename.(reflect.Value) - return names.GetTableName(engine.GetTableMapper(), v) - default: - v := rValue(tablename) - t := v.Type() - if t.Kind() == reflect.Struct { - return names.GetTableName(engine.GetTableMapper(), v) - } - return engine.Quote(fmt.Sprintf("%v", tablename)) - } - return "" -} diff --git a/error.go b/error.go index 2e9cbfaa..a223fc4a 100644 --- a/error.go +++ b/error.go @@ -26,8 +26,6 @@ var ( ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") - // ErrUnSupportedSQLType parameter of SQL is not supported - ErrUnSupportedSQLType = errors.New("Unsupported sql type") ) // ErrFieldIsNotExist columns does not exist diff --git a/helpers.go b/helpers.go index 1401cbf2..e2158c24 100644 --- a/helpers.go +++ b/helpers.go @@ -9,7 +9,6 @@ import ( "fmt" "reflect" "strconv" - "strings" "time" ) @@ -138,26 +137,6 @@ func int64ToInt(id int64, tp reflect.Type) interface{} { return int64ToIntValue(id, tp).Interface() } -func indexNoCase(s, sep string) int { - return strings.Index(strings.ToLower(s), strings.ToLower(sep)) -} - -func splitNoCase(s, sep string) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.Split(s, s[idx:idx+len(sep)]) -} - -func splitNNoCase(s, sep string, n int) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.SplitN(s, s[idx:idx+len(sep)], n) -} - func makeArray(elem string, count int) []string { res := make([]string, count) for i := 0; i < count; i++ { @@ -166,10 +145,6 @@ func makeArray(elem string, count int) []string { return res } -func rValue(bean interface{}) reflect.Value { - return reflect.Indirect(reflect.ValueOf(bean)) -} - func rType(bean interface{}) reflect.Type { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) // return reflect.TypeOf(sliceValue.Interface()) @@ -183,10 +158,6 @@ func structName(v reflect.Type) string { return v.Name() } -func indexName(tableName, idxName string) string { - return fmt.Sprintf("IDX_%v_%v", tableName, idxName) -} - func formatTime(t time.Time) string { return t.Format("2006-01-02 15:04:05") } diff --git a/interface.go b/interface.go index d7e5b778..e7894012 100644 --- a/interface.go +++ b/interface.go @@ -7,7 +7,6 @@ package xorm import ( "context" "database/sql" - "encoding/json" "reflect" "time" @@ -113,7 +112,7 @@ type EngineInterface interface { Sync(...interface{}) error Sync2(...interface{}) error StoreEngine(storeEngine string) *Session - TableInfo(bean interface{}) *Table + TableInfo(bean interface{}) (*Table, error) TableName(interface{}, ...bool) string UnMapType(reflect.Type) } @@ -123,27 +122,3 @@ var ( _ EngineInterface = &Engine{} _ EngineInterface = &EngineGroup{} ) - -// JSONInterface represents an interface to handle json data -type JSONInterface interface { - Marshal(v interface{}) ([]byte, error) - Unmarshal(data []byte, v interface{}) error -} - -var ( - // DefaultJSONHandler default json handler - DefaultJSONHandler JSONInterface = StdJSON{} -) - -// StdJSON implements JSONInterface via encoding/json -type StdJSON struct{} - -// Marshal implements JSONInterface -func (StdJSON) Marshal(v interface{}) ([]byte, error) { - return json.Marshal(v) -} - -// Unmarshal implements JSONInterface -func (StdJSON) Unmarshal(data []byte, v interface{}) error { - return json.Unmarshal(data, v) -} diff --git a/internal/json/json.go b/internal/json/json.go new file mode 100644 index 00000000..c9a2eb4e --- /dev/null +++ b/internal/json/json.go @@ -0,0 +1,31 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +import "encoding/json" + +// JSONInterface represents an interface to handle json data +type JSONInterface interface { + Marshal(v interface{}) ([]byte, error) + Unmarshal(data []byte, v interface{}) error +} + +var ( + // DefaultJSONHandler default json handler + DefaultJSONHandler JSONInterface = StdJSON{} +) + +// StdJSON implements JSONInterface via encoding/json +type StdJSON struct{} + +// Marshal implements JSONInterface +func (StdJSON) Marshal(v interface{}) ([]byte, error) { + return json.Marshal(v) +} + +// Unmarshal implements JSONInterface +func (StdJSON) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} diff --git a/internal/statements/cache.go b/internal/statements/cache.go new file mode 100644 index 00000000..d7f72318 --- /dev/null +++ b/internal/statements/cache.go @@ -0,0 +1,79 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "fmt" + "strings" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) ConvertIDSQL(sqlStr string) string { + if statement.RefTable != nil { + cols := statement.RefTable.PKColumns() + if len(cols) == 0 { + return "" + } + + colstrs := statement.joinColumns(cols, false) + sqls := utils.SplitNNoCase(sqlStr, " from ", 2) + if len(sqls) != 2 { + return "" + } + + var top string + pLimitN := statement.LimitN + if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL { + top = fmt.Sprintf("TOP %d ", *pLimitN) + } + + newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) + return newsql + } + return "" +} + +func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { + return "", "" + } + + colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) + if len(sqls) != 2 { + if len(sqls) == 1 { + return sqls[0], fmt.Sprintf("SELECT %v FROM %v", + colstrs, statement.quote(statement.TableName())) + } + return "", "" + } + + var whereStr = sqls[1] + + // TODO: for postgres only, if any other database? + var paraStr string + if statement.dialect.DBType() == schemas.POSTGRES { + paraStr = "$" + } else if statement.dialect.DBType() == schemas.MSSQL { + paraStr = ":" + } + + if paraStr != "" { + if strings.Contains(sqls[1], paraStr) { + dollers := strings.Split(sqls[1], paraStr) + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) + } + } + } + + return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", + colstrs, statement.quote(statement.TableName()), + whereStr) +} diff --git a/statement_columnmap.go b/internal/statements/column_map.go similarity index 52% rename from statement_columnmap.go rename to internal/statements/column_map.go index b6523b1e..8440f821 100644 --- a/statement_columnmap.go +++ b/internal/statements/column_map.go @@ -2,13 +2,17 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements -import "strings" +import ( + "strings" + + "xorm.io/xorm/schemas" +) type columnMap []string -func (m columnMap) contain(colName string) bool { +func (m columnMap) Contain(colName string) bool { if len(m) == 0 { return false } @@ -27,9 +31,28 @@ func (m columnMap) contain(colName string) bool { } func (m *columnMap) add(colName string) bool { - if m.contain(colName) { + if m.Contain(colName) { return false } *m = append(*m, colName) return true } + +func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) { + if len(m) == 0 { + return false, false + } + + n := len(col.Name) + + for mk := range m { + if len(mk) != n { + continue + } + if strings.EqualFold(mk, col.Name) { + return m[mk], true + } + } + + return false, false +} diff --git a/statement_exprparam.go b/internal/statements/expr_param.go similarity index 76% rename from statement_exprparam.go rename to internal/statements/expr_param.go index 3231f86a..6657408e 100644 --- a/statement_exprparam.go +++ b/internal/statements/expr_param.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "fmt" @@ -26,21 +26,21 @@ type exprParam struct { } type exprParams struct { - colNames []string - args []interface{} + ColNames []string + Args []interface{} } func (exprs *exprParams) Len() int { - return len(exprs.colNames) + return len(exprs.ColNames) } func (exprs *exprParams) addParam(colName string, arg interface{}) { - exprs.colNames = append(exprs.colNames, colName) - exprs.args = append(exprs.args, arg) + exprs.ColNames = append(exprs.ColNames, colName) + exprs.Args = append(exprs.Args, arg) } -func (exprs *exprParams) isColExist(colName string) bool { - for _, name := range exprs.colNames { +func (exprs *exprParams) IsColExist(colName string) bool { + for _, name := range exprs.ColNames { if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { return true } @@ -49,16 +49,16 @@ func (exprs *exprParams) isColExist(colName string) bool { } func (exprs *exprParams) getByName(colName string) (exprParam, bool) { - for i, name := range exprs.colNames { + for i, name := range exprs.ColNames { if strings.EqualFold(name, colName) { - return exprParam{name, exprs.args[i]}, true + return exprParam{name, exprs.Args[i]}, true } } return exprParam{}, false } -func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { - for i, expr := range exprs.args { +func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error { + for i, expr := range exprs.Args { switch arg := expr.(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { @@ -83,7 +83,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { } w.Append(arg) } - if i != len(exprs.args)-1 { + if i != len(exprs.Args)-1 { if _, err := w.WriteString(","); err != nil { return err } @@ -93,7 +93,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { } func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { - for i, colName := range exprs.colNames { + for i, colName := range exprs.ColNames { if _, err := w.WriteString(colName); err != nil { return err } @@ -101,7 +101,7 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } - switch arg := exprs.args[i].(type) { + switch arg := exprs.Args[i].(type) { case *builder.Builder: if _, err := w.WriteString("("); err != nil { return err @@ -113,10 +113,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { return err } default: - w.Append(exprs.args[i]) + w.Append(exprs.Args[i]) } - if i+1 != len(exprs.colNames) { + if i+1 != len(exprs.ColNames) { if _, err := w.WriteString(","); err != nil { return err } diff --git a/internal/statements/query.go b/internal/statements/query.go new file mode 100644 index 00000000..1519cb08 --- /dev/null +++ b/internal/statements/query.go @@ -0,0 +1,448 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "xorm.io/builder" + "xorm.io/xorm/schemas" +) + +func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { + if len(sqlOrArgs) > 0 { + return ConvertSQLOrArgs(sqlOrArgs...) + } + + if statement.RawSQL != "" { + return statement.RawSQL, statement.RawParams, nil + } + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + if statement.JoinStr == "" { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + if err := statement.ProcessIDParam(); err != nil { + return "", nil, err + } + + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err + } + + args := append(statement.joinArgs, condArgs...) + sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) + if err != nil { + return "", nil, err + } + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + + return sqlStr, args, nil +} + +func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.RawSQL, statement.RawParams, nil + } + + statement.SetRefBean(bean) + + var sumStrs = make([]string, 0, len(columns)) + for _, colName := range columns { + if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { + colName = statement.quote(colName) + } + sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) + } + sumSelect := strings.Join(sumStrs, ", ") + + condSQL, condArgs, err := statement.GenConds(bean) + if err != nil { + return "", nil, err + } + + sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { + v := rValue(bean) + isStruct := v.Kind() == reflect.Struct + if isStruct { + statement.SetRefBean(bean) + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + // TODO: always generate column names, not use * even if join + if len(statement.JoinStr) == 0 { + if len(columnStr) == 0 { + if len(statement.GroupByStr) > 0 { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if len(columnStr) == 0 { + if len(statement.GroupByStr) > 0 { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } + } + } + } + + if len(columnStr) == 0 { + columnStr = "*" + } + + if isStruct { + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } + } else { + if err := statement.ProcessIDParam(); err != nil { + return "", nil, err + } + } + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err + } + + sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.RawSQL, statement.RawParams, nil + } + + var condSQL string + var condArgs []interface{} + var err error + if len(beans) > 0 { + statement.SetRefBean(beans[0]) + condSQL, condArgs, err = statement.GenConds(beans[0]) + } else { + condSQL, condArgs, err = builder.ToSQL(statement.cond) + } + if err != nil { + return "", nil, err + } + + var selectSQL = statement.SelectStr + if len(selectSQL) <= 0 { + if statement.IsDistinct { + selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr()) + } else { + selectSQL = "count(*)" + } + } + sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil +} + +func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { + var ( + distinct string + dialect = statement.dialect + quote = statement.quote + fromStr = " FROM " + top, mssqlCondi, whereStr string + ) + if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { + distinct = "DISTINCT " + } + if len(condSQL) > 0 { + whereStr = " WHERE " + condSQL + } + + if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { + fromStr += statement.TableName() + } else { + fromStr += quote(statement.TableName()) + } + + if statement.TableAlias != "" { + if dialect.DBType() == schemas.ORACLE { + fromStr += " " + quote(statement.TableAlias) + } else { + fromStr += " AS " + quote(statement.TableAlias) + } + } + if statement.JoinStr != "" { + fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) + } + + pLimitN := statement.LimitN + if dialect.DBType() == schemas.MSSQL { + if pLimitN != nil { + LimitNValue := *pLimitN + top = fmt.Sprintf("TOP %d ", LimitNValue) + } + if statement.Start > 0 { + var column string + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } + } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } else { + column = statement.RefTable.PKColumns()[0].Name + } + if statement.needTableName() { + if len(statement.TableAlias) > 0 { + column = statement.TableAlias + "." + column + } else { + column = statement.TableName() + "." + column + } + } + + var orderStr string + if needOrderBy && len(statement.OrderStr) > 0 { + orderStr = " ORDER BY " + statement.OrderStr + } + + var groupStr string + if len(statement.GroupByStr) > 0 { + groupStr = " GROUP BY " + statement.GroupByStr + } + mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))", + column, statement.Start, column, fromStr, whereStr, orderStr, groupStr) + } + } + + var buf strings.Builder + fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) + if len(mssqlCondi) > 0 { + if len(whereStr) > 0 { + fmt.Fprint(&buf, " AND ", mssqlCondi) + } else { + fmt.Fprint(&buf, " WHERE ", mssqlCondi) + } + } + + if statement.GroupByStr != "" { + fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) + } + if statement.HavingStr != "" { + fmt.Fprint(&buf, " ", statement.HavingStr) + } + if needOrderBy && statement.OrderStr != "" { + fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) + } + if needLimit { + if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE { + if statement.Start > 0 { + if pLimitN != nil { + fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) + } else { + fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start) + } + } else if pLimitN != nil { + fmt.Fprint(&buf, " LIMIT ", *pLimitN) + } + } else if dialect.DBType() == schemas.ORACLE { + if statement.Start != 0 || pLimitN != nil { + oldString := buf.String() + buf.Reset() + rawColStr := columnStr + if rawColStr == "*" { + rawColStr = "at.*" + } + fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) + } + } + } + if statement.IsForUpdate { + return dialect.ForUpdateSQL(buf.String()), nil + } + + return buf.String(), nil +} + +func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.RawSQL, statement.RawParams, nil + } + + var sqlStr string + var args []interface{} + var joinStr string + var err error + if len(bean) == 0 { + tableName := statement.TableName() + if len(tableName) <= 0 { + return "", nil, ErrTableNotFound + } + + tableName = statement.quote(tableName) + if len(statement.JoinStr) > 0 { + joinStr = statement.JoinStr + } + + if statement.Conds().IsValid() { + condSQL, condArgs, err := builder.ToSQL(statement.Conds()) + if err != nil { + return "", nil, err + } + + if statement.dialect.DBType() == schemas.MSSQL { + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) + } else if statement.dialect.DBType() == schemas.ORACLE { + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) + } + args = condArgs + } else { + if statement.dialect.DBType() == schemas.MSSQL { + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) + } else if statement.dialect.DBType() == schemas.ORACLE { + sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) + } + args = []interface{}{} + } + } else { + beanValue := reflect.ValueOf(bean[0]) + if beanValue.Kind() != reflect.Ptr { + return "", nil, errors.New("needs a pointer") + } + + if beanValue.Elem().Kind() == reflect.Struct { + if err := statement.SetRefBean(bean[0]); err != nil { + return "", nil, err + } + } + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + statement.Limit(1) + sqlStr, args, err = statement.GenGetSQL(bean[0]) + if err != nil { + return "", nil, err + } + } + + return sqlStr, args, nil +} + +func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { + if statement.RawSQL != "" { + return statement.RawSQL, statement.RawParams, nil + } + + var sqlStr string + var args []interface{} + var err error + + if len(statement.TableName()) <= 0 { + return "", nil, ErrTableNotFound + } + + var columnStr = statement.ColumnStr() + if len(statement.SelectStr) > 0 { + columnStr = statement.SelectStr + } else { + if statement.JoinStr == "" { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if statement.GroupByStr != "" { + columnStr = statement.quoteColumnStr(statement.GroupByStr) + } else { + columnStr = "*" + } + } + } + if columnStr == "" { + columnStr = "*" + } + } + + statement.cond = statement.cond.And(autoCond) + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err + } + + args = append(statement.joinArgs, condArgs...) + sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true) + if err != nil { + return "", nil, err + } + // for mssql and use limit + qs := strings.Count(sqlStr, "?") + if len(args)*2 == qs { + args = append(args, args...) + } + + return sqlStr, args, nil +} diff --git a/statement.go b/internal/statements/statement.go similarity index 55% rename from statement.go rename to internal/statements/statement.go index 3a823d82..92b1809a 100644 --- a/statement.go +++ b/internal/statements/statement.go @@ -2,27 +2,44 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "database/sql/driver" + "errors" "fmt" "reflect" "strings" "time" "xorm.io/builder" + "xorm.io/xorm/contexts" "xorm.io/xorm/convert" "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" + "xorm.io/xorm/tags" +) + +var ( + // ErrConditionType condition type unsupported + ErrConditionType = errors.New("Unsupported condition type") + // ErrUnSupportedSQLType parameter of SQL is not supported + ErrUnSupportedSQLType = errors.New("Unsupported sql type") + // ErrUnSupportedType unsupported error + ErrUnSupportedType = errors.New("Unsupported type error") + // ErrTableNotFound table not found error + ErrTableNotFound = errors.New("Table not found") ) // Statement save all the sql info for executing SQL type Statement struct { - RefTable *schemas.Table - dialect dialects.Dialect - Engine *Engine + RefTable *schemas.Table + dialect dialects.Dialect + //Engine *Engine + defaultTimeZone *time.Location + tagParser *tags.Parser Start int LimitN *int idParam *schemas.PK @@ -31,7 +48,7 @@ type Statement struct { joinArgs []interface{} GroupByStr string HavingStr string - selectStr string + SelectStr string useAllCols bool AltTableName string tableName string @@ -43,36 +60,47 @@ type Statement struct { Charset string UseCache bool UseAutoTime bool - noAutoCondition bool + NoAutoCondition bool IsDistinct bool IsForUpdate bool TableAlias string allUseBool bool - checkVersion bool + CheckVersion bool unscoped bool - columnMap columnMap - omitColumnMap columnMap - mustColumnMap map[string]bool - nullableMap map[string]bool - incrColumns exprParams - decrColumns exprParams - exprColumns exprParams + ColumnMap columnMap + OmitColumnMap columnMap + MustColumnMap map[string]bool + NullableMap map[string]bool + IncrColumns exprParams + DecrColumns exprParams + ExprColumns exprParams cond builder.Cond - bufferSize int - context ContextCache - lastError error + BufferSize int + Context contexts.ContextCache + LastError error } -func newStatement(dialect dialects.Dialect) *Statement { +// NewStatement creates a new statement +func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZone *time.Location) *Statement { statement := &Statement{ - dialect: dialect, + dialect: dialect, + tagParser: tagParser, + defaultTimeZone: defaultTimeZone, } statement.Reset() return statement } +func (statement *Statement) SetTableName(tableName string) { + statement.tableName = tableName +} + func (statement *Statement) omitStr() string { - return statement.dialect.Quoter().Join(statement.omitColumnMap, " ,") + return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") +} + +func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { + statement.Context = ctxCache } // Init reset all the statement's fields @@ -86,8 +114,8 @@ func (statement *Statement) Reset() { statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" statement.HavingStr = "" - statement.columnMap = columnMap{} - statement.omitColumnMap = columnMap{} + statement.ColumnMap = columnMap{} + statement.OmitColumnMap = columnMap{} statement.AltTableName = "" statement.tableName = "" statement.idParam = nil @@ -95,31 +123,31 @@ func (statement *Statement) Reset() { statement.RawParams = make([]interface{}, 0) statement.UseCache = true statement.UseAutoTime = true - statement.noAutoCondition = false + statement.NoAutoCondition = false statement.IsDistinct = false statement.IsForUpdate = false statement.TableAlias = "" - statement.selectStr = "" + statement.SelectStr = "" statement.allUseBool = false statement.useAllCols = false - statement.mustColumnMap = make(map[string]bool) - statement.nullableMap = make(map[string]bool) - statement.checkVersion = true + statement.MustColumnMap = make(map[string]bool) + statement.NullableMap = make(map[string]bool) + statement.CheckVersion = true statement.unscoped = false - statement.incrColumns = exprParams{} - statement.decrColumns = exprParams{} - statement.exprColumns = exprParams{} + statement.IncrColumns = exprParams{} + statement.DecrColumns = exprParams{} + statement.ExprColumns = exprParams{} statement.cond = builder.NewCond() - statement.bufferSize = 0 - statement.context = nil - statement.lastError = nil + statement.BufferSize = 0 + statement.Context = nil + statement.LastError = nil } // NoAutoCondition if you do not want convert bean's field as query condition, then use this function -func (statement *Statement) NoAutoCondition(no ...bool) *Statement { - statement.noAutoCondition = true +func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { + statement.NoAutoCondition = true if len(no) > 0 { - statement.noAutoCondition = no[0] + statement.NoAutoCondition = no[0] } return statement } @@ -137,13 +165,13 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme var err error statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() if err != nil { - statement.lastError = err + statement.LastError = err } case string: statement.RawSQL = query.(string) statement.RawParams = args default: - statement.lastError = ErrUnSupportedSQLType + statement.LastError = ErrUnSupportedSQLType } return statement @@ -180,7 +208,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme } } default: - statement.lastError = ErrConditionType + statement.LastError = ErrConditionType } return statement @@ -223,291 +251,30 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement return statement } -func (statement *Statement) setRefValue(v reflect.Value) error { +func (statement *Statement) SetRefValue(v reflect.Value) error { var err error - statement.RefTable, err = statement.Engine.tagParser.MapType(reflect.Indirect(v)) + statement.RefTable, err = statement.tagParser.MapType(reflect.Indirect(v)) if err != nil { return err } - statement.tableName = statement.Engine.TableName(v, true) + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, v, true) return nil } -func (statement *Statement) setRefBean(bean interface{}) error { +func rValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} + +func (statement *Statement) SetRefBean(bean interface{}) error { var err error - statement.RefTable, err = statement.Engine.tagParser.MapType(rValue(bean)) + statement.RefTable, err = statement.tagParser.MapType(rValue(bean)) if err != nil { return err } - statement.tableName = statement.Engine.TableName(bean, true) + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, bean, true) return nil } -// Auto generating update columnes and values according a struct -func (statement *Statement) buildUpdates(bean interface{}, - includeVersion, includeUpdated, includeNil, - includeAutoIncr, update bool) ([]string, []interface{}) { - engine := statement.Engine - table := statement.RefTable - allUseBool := statement.allUseBool - useAllCols := statement.useAllCols - mustColumnMap := statement.mustColumnMap - nullableMap := statement.nullableMap - columnMap := statement.columnMap - omitColumnMap := statement.omitColumnMap - unscoped := statement.unscoped - - var colNames = make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if col.IsCreated && !columnMap.contain(col.Name) { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - if col.IsDeleted && !unscoped { - continue - } - if omitColumnMap.contain(col.Name) { - continue - } - if len(columnMap) > 0 && !columnMap.contain(col.Name) { - continue - } - - if col.MapType == schemas.ONLYFROMDB { - continue - } - - if statement.incrColumns.isColExist(col.Name) { - continue - } else if statement.decrColumns.isColExist(col.Name) { - continue - } else if statement.exprColumns.isColExist(col.Name) { - continue - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - engine.logger.Error(err) - continue - } - - fieldValue := *fieldValuePtr - fieldType := reflect.TypeOf(fieldValue.Interface()) - if fieldType == nil { - continue - } - - requiredField := useAllCols - includeNil := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if b, ok := getFlagForColumn(nullableMap, col); ok { - if b && col.Nullable && utils.IsZero(fieldValue.Interface()) { - var nilValue *int - fieldValue = reflect.ValueOf(nilValue) - fieldType = reflect.TypeOf(fieldValue.Interface()) - includeNil = true - } - } - - var val interface{} - - if fieldValue.CanAddr() { - if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { - data, err := structConvert.ToDB() - if err != nil { - engine.logger.Error(err) - } else { - val = data - } - goto APPEND - } - } - - if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { - data, err := structConvert.ToDB() - if err != nil { - engine.logger.Error(err) - } else { - val = data - } - goto APPEND - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name))) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = nulType.Value() - } else { - if !col.SQLType.IsJson() { - table, err := engine.tagParser.MapType(fieldValue) - if err != nil { - val = fieldValue.Interface() - } else { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - if pkField.IsValid() && (!requiredField && !utils.IsZero(pkField.Interface())) { - val = pkField.Interface() - } else { - continue - } - } else { - // TODO: how to handler? - panic("not supported") - } - } - } else { - // Blank struct could not be as update data - if requiredField || !utils.IsStructZero(fieldValue) { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) - } - if col.SQLType.IsText() { - val = string(bytes) - } else if col.SQLType.IsBlob() { - val = bytes - } - } else { - continue - } - } - } - case reflect.Array, reflect.Slice, reflect.Map: - if !requiredField { - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldType.Kind() == reflect.Array { - if utils.IsArrayZero(fieldValue) { - continue - } - } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - } - - if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if fieldType.Kind() == reflect.Slice && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else if fieldType.Kind() == reflect.Array && - fieldType.Elem().Kind() == reflect.Uint8 { - val = fieldValue.Slice(0, 0).Interface() - } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - APPEND: - args = append(args, val) - if col.IsPrimaryKey { - continue - } - colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name))) - } - - return colNames, args -} - func (statement *Statement) needTableName() bool { return len(statement.JoinStr) > 0 } @@ -563,9 +330,9 @@ func (statement *Statement) ID(id interface{}) *Statement { // Incr Generate "Update ... Set column = column + arg" statement func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { - statement.incrColumns.addParam(column, arg[0]) + statement.IncrColumns.addParam(column, arg[0]) } else { - statement.incrColumns.addParam(column, 1) + statement.IncrColumns.addParam(column, 1) } return statement } @@ -573,16 +340,16 @@ func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { // Decr Generate "Update ... Set column = column - arg" statement func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { if len(arg) > 0 { - statement.decrColumns.addParam(column, arg[0]) + statement.DecrColumns.addParam(column, arg[0]) } else { - statement.decrColumns.addParam(column, 1) + statement.DecrColumns.addParam(column, 1) } return statement } // SetExpr Generate "Update ... Set column = {expression}" statement func (statement *Statement) SetExpr(column string, expression interface{}) *Statement { - statement.exprColumns.addParam(column, expression) + statement.ExprColumns.addParam(column, expression) return statement } @@ -601,21 +368,34 @@ func (statement *Statement) ForUpdate() *Statement { // Select replace select func (statement *Statement) Select(str string) *Statement { - statement.selectStr = str + statement.SelectStr = str return statement } +func col2NewCols(columns ...string) []string { + newColumns := make([]string, 0, len(columns)) + for _, col := range columns { + col = strings.Replace(col, "`", "", -1) + col = strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns +} + // Cols generate "col1, col2" statement func (statement *Statement) Cols(columns ...string) *Statement { cols := col2NewCols(columns...) for _, nc := range cols { - statement.columnMap.add(nc) + statement.ColumnMap.add(nc) } return statement } -func (statement *Statement) columnStr() string { - return statement.dialect.Quoter().Join(statement.columnMap, ", ") +func (statement *Statement) ColumnStr() string { + return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") } // AllCols update use only: update all columns @@ -628,7 +408,7 @@ func (statement *Statement) AllCols() *Statement { func (statement *Statement) MustCols(columns ...string) *Statement { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.mustColumnMap[strings.ToLower(nc)] = true + statement.MustColumnMap[strings.ToLower(nc)] = true } return statement } @@ -647,7 +427,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement { func (statement *Statement) Omit(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.omitColumnMap = append(statement.omitColumnMap, nc) + statement.OmitColumnMap = append(statement.OmitColumnMap, nc) } } @@ -655,7 +435,7 @@ func (statement *Statement) Omit(columns ...string) { func (statement *Statement) Nullable(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { - statement.nullableMap[strings.ToLower(nc)] = true + statement.NullableMap[strings.ToLower(nc)] = true } } @@ -717,21 +497,24 @@ func (statement *Statement) Asc(colNames ...string) *Statement { return statement } +func (statement *Statement) Conds() builder.Cond { + return statement.cond +} + // Table tempororily set table name, the parameter could be a string or a pointer of struct -func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { +func (statement *Statement) SetTable(tableNameOrBean interface{}) error { v := rValue(tableNameOrBean) t := v.Type() if t.Kind() == reflect.Struct { var err error - statement.RefTable, err = statement.Engine.tagParser.MapType(v) + statement.RefTable, err = statement.tagParser.MapType(v) if err != nil { - statement.Engine.logger.Error(err) - return statement + return err } } - statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) - return statement + statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tableNameOrBean, true) + return nil } // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN @@ -747,7 +530,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition case builder.Builder: subSQL, subQueryArgs, err := tp.ToSQL() if err != nil { - statement.lastError = err + statement.LastError = err return statement } @@ -760,7 +543,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition case *builder.Builder: subSQL, subQueryArgs, err := tp.ToSQL() if err != nil { - statement.lastError = err + statement.LastError = err return statement } @@ -771,8 +554,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: - tbName := statement.Engine.TableName(tablename, true) - if !isSubQuery(tbName) { + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tablename, true) + if !utils.IsSubQuery(tbName) { var buf strings.Builder statement.dialect.Quoter().QuoteTo(&buf, tbName) tbName = buf.String() @@ -785,6 +568,15 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition return statement } +// tbName get some table's table name +func (statement *Statement) tbNameNoSchema(table *schemas.Table) string { + if len(statement.AltTableName) > 0 { + return statement.AltTableName + } + + return table.Name +} + // GroupBy generate "Group By keys" statement func (statement *Statement) GroupBy(keys string) *Statement { statement.GroupByStr = keys @@ -798,11 +590,15 @@ func (statement *Statement) Having(conditions string) *Statement { } // Unscoped always disable struct tag "deleted" -func (statement *Statement) Unscoped() *Statement { +func (statement *Statement) SetUnscoped() *Statement { statement.unscoped = true return statement } +func (statement *Statement) GetUnscoped() bool { + return statement.unscoped +} + func (statement *Statement) genColumnStr() string { if statement.RefTable == nil { return "" @@ -812,11 +608,11 @@ func (statement *Statement) genColumnStr() string { columns := statement.RefTable.Columns() for _, col := range columns { - if statement.omitColumnMap.contain(col.Name) { + if statement.OmitColumnMap.Contain(col.Name) { continue } - if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) { + if len(statement.ColumnMap) > 0 && !statement.ColumnMap.Contain(col.Name) { continue } @@ -844,12 +640,12 @@ func (statement *Statement) genColumnStr() string { return buf.String() } -func (statement *Statement) genCreateTableSQL() string { +func (statement *Statement) GenCreateTableSQL() string { return statement.dialect.CreateTableSQL(statement.RefTable, statement.TableName(), statement.StoreEngine, statement.Charset) } -func (statement *Statement) genIndexSQL() []string { +func (statement *Statement) GenIndexSQL() []string { var sqls []string tbName := statement.TableName() for _, index := range statement.RefTable.Indexes { @@ -865,7 +661,7 @@ func uniqueName(tableName, uqeName string) string { return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) } -func (statement *Statement) genUniqueSQL() []string { +func (statement *Statement) GenUniqueSQL() []string { var sqls []string tbName := statement.TableName() for _, index := range statement.RefTable.Indexes { @@ -877,7 +673,7 @@ func (statement *Statement) genUniqueSQL() []string { return sqls } -func (statement *Statement) genDelIndexSQL() []string { +func (statement *Statement) GenDelIndexSQL() []string { var sqls []string tbName := statement.TableName() idx := strings.Index(tbName, ".") @@ -891,9 +687,9 @@ func (statement *Statement) genDelIndexSQL() []string { if index.Type == schemas.UniqueType { rIdxName = uniqueName(idxPrefixName, idxName) } else if index.Type == schemas.IndexType { - rIdxName = indexName(idxPrefixName, idxName) + rIdxName = utils.IndexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.quote(statement.Engine.TableName(rIdxName, true))) + sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, rIdxName, true))) if statement.dialect.IndexOnTable() { sql += fmt.Sprintf(" ON %v", statement.quote(tbName)) } @@ -902,28 +698,240 @@ func (statement *Statement) genDelIndexSQL() []string { return sqls } -func (statement *Statement) buildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) +func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { + var conds []builder.Cond + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + + if statement.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) { + continue + } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = statement.quote(nm) + "." + statement.quote(col.Name) + } else { + colName = statement.quote(col.Name) + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + if !strings.Contains(err.Error(), "is not valid") { + //engine.logger.Warn(err) + } + continue + } + + if col.IsDeleted && !unscoped { // tag "deleted" is enabled + conds = append(conds, statement.CondDeleted(col)) + } + + fieldValue := *fieldValuePtr + if fieldValue.Interface() == nil { + continue + } + + fieldType := reflect.TypeOf(fieldValue.Interface()) + requiredField := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + conds = append(conds, builder.Eq{colName: nil}) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil { + continue + } + } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = bytes + } + } else { + table, err := statement.tagParser.MapType(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + } + } + case reflect.Array: + continue + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, err + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + conds = append(conds, builder.Eq{colName: val}) + } + + return builder.And(conds...), nil +} + +func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { + return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) } func (statement *Statement) mergeConds(bean interface{}) error { - if !statement.noAutoCondition { + if !statement.NoAutoCondition { var addedTableName = (len(statement.JoinStr) > 0) - autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) + autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { return err } statement.cond = statement.cond.And(autoCond) } - if err := statement.processIDParam(); err != nil { + if err := statement.ProcessIDParam(); err != nil { return err } return nil } -func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { +func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) { if err := statement.mergeConds(bean); err != nil { return "", nil, err } @@ -936,242 +944,21 @@ func (statement *Statement) quoteColumnStr(columnStr string) string { return statement.dialect.Quoter().Join(columns, ",") } -func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { - v := rValue(bean) - isStruct := v.Kind() == reflect.Struct - if isStruct { - statement.setRefBean(bean) +func ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { + switch sqlOrArgs[0].(type) { + case string: + return sqlOrArgs[0].(string), sqlOrArgs[1:], nil + case *builder.Builder: + return sqlOrArgs[0].(*builder.Builder).ToSQL() + case builder.Builder: + bd := sqlOrArgs[0].(builder.Builder) + return bd.ToSQL() } - var columnStr = statement.columnStr() - if len(statement.selectStr) > 0 { - columnStr = statement.selectStr - } else { - // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.quoteColumnStr(statement.GroupByStr) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.quoteColumnStr(statement.GroupByStr) - } - } - } - } - - if len(columnStr) == 0 { - columnStr = "*" - } - - if isStruct { - if err := statement.mergeConds(bean); err != nil { - return "", nil, err - } - } else { - if err := statement.processIDParam(); err != nil { - return "", nil, err - } - } - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - - sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil + return "", nil, ErrUnSupportedType } -func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) { - var condSQL string - var condArgs []interface{} - var err error - if len(beans) > 0 { - statement.setRefBean(beans[0]) - condSQL, condArgs, err = statement.genConds(beans[0]) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) - } - if err != nil { - return "", nil, err - } - - var selectSQL = statement.selectStr - if len(selectSQL) <= 0 { - if statement.IsDistinct { - selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.columnStr()) - } else { - selectSQL = "count(*)" - } - } - sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { - statement.setRefBean(bean) - - var sumStrs = make([]string, 0, len(columns)) - for _, colName := range columns { - if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") { - colName = statement.quote(colName) - } - sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) - } - sumSelect := strings.Join(sumStrs, ", ") - - condSQL, condArgs, err := statement.genConds(bean) - if err != nil { - return "", nil, err - } - - sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil -} - -func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { - var ( - distinct string - dialect = statement.dialect - quote = statement.quote - fromStr = " FROM " - top, mssqlCondi, whereStr string - ) - if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { - distinct = "DISTINCT " - } - if len(condSQL) > 0 { - whereStr = " WHERE " + condSQL - } - - if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { - fromStr += statement.TableName() - } else { - fromStr += quote(statement.TableName()) - } - - if statement.TableAlias != "" { - if dialect.DBType() == schemas.ORACLE { - fromStr += " " + quote(statement.TableAlias) - } else { - fromStr += " AS " + quote(statement.TableAlias) - } - } - if statement.JoinStr != "" { - fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) - } - - pLimitN := statement.LimitN - if dialect.DBType() == schemas.MSSQL { - if pLimitN != nil { - LimitNValue := *pLimitN - top = fmt.Sprintf("TOP %d ", LimitNValue) - } - if statement.Start > 0 { - var column string - if len(statement.RefTable.PKColumns()) == 0 { - for _, index := range statement.RefTable.Indexes { - if len(index.Cols) == 1 { - column = index.Cols[0] - break - } - } - if len(column) == 0 { - column = statement.RefTable.ColumnsSeq()[0] - } - } else { - column = statement.RefTable.PKColumns()[0].Name - } - if statement.needTableName() { - if len(statement.TableAlias) > 0 { - column = statement.TableAlias + "." + column - } else { - column = statement.TableName() + "." + column - } - } - - var orderStr string - if needOrderBy && len(statement.OrderStr) > 0 { - orderStr = " ORDER BY " + statement.OrderStr - } - - var groupStr string - if len(statement.GroupByStr) > 0 { - groupStr = " GROUP BY " + statement.GroupByStr - } - mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))", - column, statement.Start, column, fromStr, whereStr, orderStr, groupStr) - } - } - - var buf strings.Builder - fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) - if len(mssqlCondi) > 0 { - if len(whereStr) > 0 { - fmt.Fprint(&buf, " AND ", mssqlCondi) - } else { - fmt.Fprint(&buf, " WHERE ", mssqlCondi) - } - } - - if statement.GroupByStr != "" { - fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) - } - if statement.HavingStr != "" { - fmt.Fprint(&buf, " ", statement.HavingStr) - } - if needOrderBy && statement.OrderStr != "" { - fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) - } - if needLimit { - if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE { - if statement.Start > 0 { - if pLimitN != nil { - fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) - } else { - fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start) - } - } else if pLimitN != nil { - fmt.Fprint(&buf, " LIMIT ", *pLimitN) - } - } else if dialect.DBType() == schemas.ORACLE { - if statement.Start != 0 || pLimitN != nil { - oldString := buf.String() - buf.Reset() - rawColStr := columnStr - if rawColStr == "*" { - rawColStr = "at.*" - } - fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", - columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) - } - } - } - if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), nil - } - - return buf.String(), nil -} - -func (statement *Statement) processIDParam() error { +func (statement *Statement) ProcessIDParam() error { if statement.idParam == nil || statement.RefTable == nil { return nil } @@ -1203,68 +990,21 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName return strings.Join(colnames, ", ") } -func (statement *Statement) convertIDSQL(sqlStr string) string { - if statement.RefTable != nil { - cols := statement.RefTable.PKColumns() - if len(cols) == 0 { - return "" +// CondDeleted returns the conditions whether a record is soft deleted. +func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { + var cond = builder.NewCond() + if col.SQLType.IsNumeric() { + cond = builder.Eq{col.Name: 0} + } else { + // FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value. + if statement.dialect.DBType() != schemas.MSSQL { + cond = builder.Eq{col.Name: utils.ZeroTime1} } - - colstrs := statement.joinColumns(cols, false) - sqls := splitNNoCase(sqlStr, " from ", 2) - if len(sqls) != 2 { - return "" - } - - var top string - pLimitN := statement.LimitN - if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL { - top = fmt.Sprintf("TOP %d ", *pLimitN) - } - - newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) - return newsql } - return "" -} - -func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { - if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { - return "", "" - } - - colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) - sqls := splitNNoCase(sqlStr, "where", 2) - if len(sqls) != 2 { - if len(sqls) == 1 { - return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.quote(statement.TableName())) - } - return "", "" - } - - var whereStr = sqls[1] - - // TODO: for postgres only, if any other database? - var paraStr string - if statement.dialect.DBType() == schemas.POSTGRES { - paraStr = "$" - } else if statement.dialect.DBType() == schemas.MSSQL { - paraStr = ":" - } - - if paraStr != "" { - if strings.Contains(sqls[1], paraStr) { - dollers := strings.Split(sqls[1], paraStr) - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) - } - } - } - - return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.quote(statement.TableName()), - whereStr) + + if col.Nullable { + cond = cond.Or(builder.IsNull{col.Name}) + } + + return cond } diff --git a/statement_args.go b/internal/statements/statement_args.go similarity index 78% rename from statement_args.go rename to internal/statements/statement_args.go index 22bfeb7b..8eee246e 100644 --- a/statement_args.go +++ b/internal/statements/statement_args.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "fmt" @@ -77,7 +77,7 @@ func convertArg(arg interface{}, convertFunc func(string) string) string { const insertSelectPlaceHolder = true -func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { +func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case bool: if statement.dialect.DBType() == schemas.MSSQL { @@ -130,9 +130,9 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er return nil } -func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}) error { +func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { for i, arg := range args { - if err := statement.writeArg(w, arg); err != nil { + if err := statement.WriteArg(w, arg); err != nil { return err } @@ -144,27 +144,3 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{} } return nil } - -func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error { - for i, colName := range cols { - if len(leftQuote) > 0 && colName[0] != '`' { - if _, err := w.WriteString(leftQuote); err != nil { - return err - } - } - if _, err := w.WriteString(colName); err != nil { - return err - } - if len(rightQuote) > 0 && colName[len(colName)-1] != '`' { - if _, err := w.WriteString(rightQuote); err != nil { - return err - } - } - if i+1 != len(cols) { - if _, err := w.WriteString(","); err != nil { - return err - } - } - } - return nil -} diff --git a/internal/statements/statement_test.go b/internal/statements/statement_test.go new file mode 100644 index 00000000..3b6e3ae2 --- /dev/null +++ b/internal/statements/statement_test.go @@ -0,0 +1,184 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "reflect" + "strings" + "testing" + + "xorm.io/xorm/schemas" +) + +var colStrTests = []struct { + omitColumn string + onlyToDBColumnNdx int + expected string +}{ + {"", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, + {"Code2", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, + {"", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, + {"Code3", 1, "`ID`, `Caption`, `Code1`, `Code2`, `ParentID`, `Latitude`, `Longitude`"}, + {"Longitude", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, + {"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, +} + +func TestColumnsStringGeneration(t *testing.T) { + if dbType == "postgres" || dbType == "mssql" { + return + } + + var statement *Statement + + for ndx, testCase := range colStrTests { + statement = createTestStatement() + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) + } + + columns := statement.RefTable.Columns() + if testCase.onlyToDBColumnNdx >= 0 { + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB + } + + actual := statement.genColumnStr() + + if actual != testCase.expected { + t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual) + } + if testCase.onlyToDBColumnNdx >= 0 { + columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES + } + } +} + +func BenchmarkColumnsStringGeneration(b *testing.B) { + b.StopTimer() + + statement := createTestStatement() + + testCase := colStrTests[0] + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + } + + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() + + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } + } +} + +func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { + + b.StopTimer() + + mapCols := make(map[string]bool) + cols := []*schemas.Column{ + {Name: `ID`}, + {Name: `IsDeleted`}, + {Name: `Caption`}, + {Name: `Code1`}, + {Name: `Code2`}, + {Name: `Code3`}, + {Name: `ParentID`}, + {Name: `Latitude`}, + {Name: `Longitude`}, + } + + for _, col := range cols { + mapCols[strings.ToLower(col.Name)] = true + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + + for _, col := range cols { + + if _, ok := getFlagForColumn(mapCols, col); !ok { + b.Fatal("Unexpected result") + } + } + } +} + +func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { + + b.StopTimer() + + mapCols := make(map[string]bool) + cols := []*schemas.Column{ + {Name: `ID`}, + {Name: `IsDeleted`}, + {Name: `Caption`}, + {Name: `Code1`}, + {Name: `Code2`}, + {Name: `Code3`}, + {Name: `ParentID`}, + {Name: `Latitude`}, + {Name: `Longitude`}, + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + + for _, col := range cols { + + if _, ok := getFlagForColumn(mapCols, col); ok { + b.Fatal("Unexpected result") + } + } + } +} + +type TestType struct { + ID int64 `xorm:"ID PK"` + IsDeleted bool `xorm:"IsDeleted"` + Caption string `xorm:"Caption"` + Code1 string `xorm:"Code1"` + Code2 string `xorm:"Code2"` + Code3 string `xorm:"Code3"` + ParentID int64 `xorm:"ParentID"` + Latitude float64 `xorm:"Latitude"` + Longitude float64 `xorm:"Longitude"` +} + +func (TestType) TableName() string { + return "TestTable" +} + +func createTestStatement() *Statement { + if engine, ok := testEngine.(*Engine); ok { + statement := &Statement{} + statement.Reset() + statement.Engine = engine + statement.dialect = engine.dialect + statement.SetRefValue(reflect.ValueOf(TestType{})) + + return statement + } else if eg, ok := testEngine.(*EngineGroup); ok { + statement := &Statement{} + statement.Reset() + statement.Engine = eg.Engine + statement.dialect = eg.Engine.dialect + statement.SetRefValue(reflect.ValueOf(TestType{})) + + return statement + } + return nil +} diff --git a/types.go b/internal/statements/types.go similarity index 94% rename from types.go rename to internal/statements/types.go index ee725dae..0ff36f35 100644 --- a/types.go +++ b/internal/statements/types.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package xorm +package statements import ( "reflect" diff --git a/internal/statements/update.go b/internal/statements/update.go new file mode 100644 index 00000000..a5d7ec5a --- /dev/null +++ b/internal/statements/update.go @@ -0,0 +1,280 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "database/sql/driver" + "fmt" + "reflect" + "time" + + "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +// BuildUpdates auto generating update columnes and values according a struct +func (statement *Statement) BuildUpdates(bean interface{}, + includeVersion, includeUpdated, includeNil, + includeAutoIncr, update bool) ([]string, []interface{}, error) { + //engine := statement.Engine + table := statement.RefTable + allUseBool := statement.allUseBool + useAllCols := statement.useAllCols + mustColumnMap := statement.MustColumnMap + nullableMap := statement.NullableMap + columnMap := statement.ColumnMap + omitColumnMap := statement.OmitColumnMap + unscoped := statement.unscoped + + var colNames = make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if col.IsCreated && !columnMap.Contain(col.Name) { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + if col.IsDeleted && !unscoped { + continue + } + if omitColumnMap.Contain(col.Name) { + continue + } + if len(columnMap) > 0 && !columnMap.Contain(col.Name) { + continue + } + + if col.MapType == schemas.ONLYFROMDB { + continue + } + + if statement.IncrColumns.IsColExist(col.Name) { + continue + } else if statement.DecrColumns.IsColExist(col.Name) { + continue + } else if statement.ExprColumns.IsColExist(col.Name) { + continue + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + return nil, nil, err + } + + fieldValue := *fieldValuePtr + fieldType := reflect.TypeOf(fieldValue.Interface()) + if fieldType == nil { + continue + } + + requiredField := useAllCols + includeNil := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if b, ok := getFlagForColumn(nullableMap, col); ok { + if b && col.Nullable && utils.IsZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + fieldType = reflect.TypeOf(fieldValue.Interface()) + includeNil = true + } + } + + var val interface{} + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + return nil, nil, err + } + + val = data + goto APPEND + } + } + + if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + return nil, nil, err + } + + val = data + goto APPEND + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) + } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = nulType.Value() + } else { + if !col.SQLType.IsJson() { + table, err := statement.tagParser.MapType(fieldValue) + if err != nil { + val = fieldValue.Interface() + } else { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + if pkField.IsValid() && (!requiredField && !utils.IsZero(pkField.Interface())) { + val = pkField.Interface() + } else { + continue + } + } else { + // TODO: how to handler? + panic("not supported") + } + } + } else { + // Blank struct could not be as update data + if requiredField || !utils.IsStructZero(fieldValue) { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) + } + if col.SQLType.IsText() { + val = string(bytes) + } else if col.SQLType.IsBlob() { + val = bytes + } + } else { + continue + } + } + } + case reflect.Array, reflect.Slice, reflect.Map: + if !requiredField { + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldType.Kind() == reflect.Array { + if utils.IsArrayZero(fieldValue) { + continue + } + } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, nil, err + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if fieldType.Kind() == reflect.Slice && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else if fieldType.Kind() == reflect.Array && + fieldType.Elem().Kind() == reflect.Uint8 { + val = fieldValue.Slice(0, 0).Interface() + } else { + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, nil, err + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + APPEND: + args = append(args, val) + if col.IsPrimaryKey { + continue + } + colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name))) + } + + return colNames, args, nil +} diff --git a/internal/utils/name.go b/internal/utils/name.go new file mode 100644 index 00000000..f5fc3ff7 --- /dev/null +++ b/internal/utils/name.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "fmt" +) + +func IndexName(tableName, idxName string) string { + return fmt.Sprintf("IDX_%v_%v", tableName, idxName) +} diff --git a/internal/utils/reflect.go b/internal/utils/reflect.go new file mode 100644 index 00000000..3dad6bfe --- /dev/null +++ b/internal/utils/reflect.go @@ -0,0 +1,13 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "reflect" +) + +func ReflectValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} diff --git a/internal/utils/sql.go b/internal/utils/sql.go new file mode 100644 index 00000000..5e68c4a4 --- /dev/null +++ b/internal/utils/sql.go @@ -0,0 +1,19 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "strings" +) + +func IsSubQuery(tbName string) bool { + const selStr = "select" + if len(tbName) <= len(selStr)+1 { + return false + } + + return strings.EqualFold(tbName[:len(selStr)], selStr) || + strings.EqualFold(tbName[:len(selStr)+1], "("+selStr) +} diff --git a/internal/utils/strings.go b/internal/utils/strings.go new file mode 100644 index 00000000..b5dc37b7 --- /dev/null +++ b/internal/utils/strings.go @@ -0,0 +1,30 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package utils + +import ( + "strings" +) + +func IndexNoCase(s, sep string) int { + return strings.Index(strings.ToLower(s), strings.ToLower(sep)) +} + +func SplitNoCase(s, sep string) []string { + idx := IndexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.Split(s, s[idx:idx+len(sep)]) +} + +func SplitNNoCase(s, sep string, n int) []string { + idx := IndexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.SplitN(s, s[idx:idx+len(sep)], n) +} + diff --git a/rows.go b/rows.go index b52b889d..e14c9894 100644 --- a/rows.go +++ b/rows.go @@ -10,6 +10,7 @@ import ( "reflect" "xorm.io/xorm/core" + "xorm.io/xorm/internal/utils" ) // Rows rows wrapper a rows to @@ -29,7 +30,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var args []interface{} var err error - if err = rows.session.statement.setRefBean(bean); err != nil { + if err = rows.session.statement.SetRefBean(bean); err != nil { return nil, err } @@ -38,7 +39,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { } if rows.session.statement.RawSQL == "" { - sqlStr, args, err = rows.session.statement.genGetSQL(bean) + sqlStr, args, err = rows.session.statement.GenGetSQL(bean) if err != nil { return nil, err } @@ -84,7 +85,7 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - if err := rows.session.statement.setRefBean(bean); err != nil { + if err := rows.session.statement.SetRefBean(bean); err != nil { return err } @@ -98,7 +99,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } - dataStruct := rValue(bean) + dataStruct := utils.ReflectValue(bean) _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable) if err != nil { return err diff --git a/schemas/quote.go b/schemas/quote.go index 0e022240..736b774a 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -109,6 +109,40 @@ func (q Quoter) Join(a []string, sep string) string { return b.String() } +func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { + if len(a) == 0 { + return nil + } + + n := len(sep) * (len(a) - 1) + for i := 0; i < len(a); i++ { + n += len(a[i]) + } + + b.Grow(n) + for i, s := range a { + if i > 0 { + if _, err := b.WriteString(sep); err != nil { + return err + } + } + if q[0] != "" && s != "*" && s[0] != '`' { + if _, err := b.WriteString(q[0]); err != nil { + return err + } + } + if _, err := b.WriteString(strings.TrimSpace(s)); err != nil { + return err + } + if q[1] != "" && s != "*" && s[0] != '`' { + if _, err := b.WriteString(q[1]); err != nil { + return err + } + } + } + return nil +} + func (q Quoter) Strings(s []string) []string { var res = make([]string, 0, len(s)) for _, a := range s { diff --git a/session.go b/session.go index 0b0f56c0..92063882 100644 --- a/session.go +++ b/session.go @@ -14,8 +14,11 @@ import ( "strings" "time" + "xorm.io/xorm/contexts" "xorm.io/xorm/convert" "xorm.io/xorm/core" + "xorm.io/xorm/internal/json" + "xorm.io/xorm/internal/statements" "xorm.io/xorm/schemas" ) @@ -32,7 +35,7 @@ type Session struct { db *core.DB engine *Engine tx *core.Tx - statement Statement + statement *statements.Statement isAutoCommit bool isCommitedOrRollbacked bool isAutoClose bool @@ -73,9 +76,12 @@ func (session *Session) Clone() *Session { // Init reset the session as the init status. func (session *Session) Init() { - session.statement.Reset() - session.statement.dialect = session.engine.dialect - session.statement.Engine = session.engine + session.statement = statements.NewStatement( + session.engine.dialect, + session.engine.tagParser, + session.engine.DatabaseTZ, + ) + session.showSQL = session.engine.showSQL session.isAutoCommit = true session.isCommitedOrRollbacked = false @@ -118,8 +124,8 @@ func (session *Session) Close() { } // ContextCache enable context cache or not -func (session *Session) ContextCache(context ContextCache) *Session { - session.statement.context = context +func (session *Session) ContextCache(context contexts.ContextCache) *Session { + session.statement.SetContextCache(context) return session } @@ -158,7 +164,9 @@ func (session *Session) After(closures func(interface{})) *Session { // Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { - session.statement.Table(tableNameOrBean) + if err := session.statement.SetTable(tableNameOrBean); err != nil { + session.engine.logger.Error(err) + } return session } @@ -182,7 +190,7 @@ func (session *Session) ForUpdate() *Session { // NoAutoCondition disable generate SQL condition from beans func (session *Session) NoAutoCondition(no ...bool) *Session { - session.statement.NoAutoCondition(no...) + session.statement.SetNoAutoCondition(no...) return session } @@ -288,7 +296,7 @@ func (session *Session) canCache() bool { !session.statement.UseCache || session.statement.IsForUpdate || session.tx != nil || - len(session.statement.selectStr) > 0 { + len(session.statement.SelectStr) > 0 { return false } return true @@ -505,13 +513,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -535,13 +543,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true if len(bs) > 0 { if fieldValue.CanAddr() { - err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) if err != nil { return nil, err } } else { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface()) if err != nil { return nil, err } @@ -557,7 +565,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true if col.SQLType.IsText() { x := reflect.New(fieldType) - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -672,7 +680,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) if err != nil { return nil, err } @@ -682,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true x := reflect.New(fieldType) if len(vv.Bytes()) > 0 { - err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) if err != nil { return nil, err } @@ -818,7 +826,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b case schemas.Complex64Type: var x complex64 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } @@ -828,7 +836,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b case schemas.Complex128Type: var x complex128 if len([]byte(vv.String())) > 0 { - err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) + err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) if err != nil { return nil, err } @@ -877,7 +885,7 @@ func (session *Session) LastSQL() (string, []interface{}) { // Unscoped always disable struct tag "deleted" func (session *Session) Unscoped() *Session { - session.statement.Unscoped() + session.statement.SetUnscoped() return session } diff --git a/session_cols.go b/session_cols.go index 4f7dc6cf..ca3589ab 100644 --- a/session_cols.go +++ b/session_cols.go @@ -63,19 +63,6 @@ func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has boo return false, false } -func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0, len(columns)) - for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns -} - // Incr provides a query string like "count = count + 1" func (session *Session) Incr(column string, arg ...interface{}) *Session { session.statement.Incr(column, arg...) diff --git a/session_cond.go b/session_cond.go index 72e3abc3..25d17148 100644 --- a/session_cond.go +++ b/session_cond.go @@ -51,5 +51,5 @@ func (session *Session) NotIn(column string, args ...interface{}) *Session { // Conds returns session query conditions except auto bean conditions func (session *Session) Conds() builder.Cond { - return session.statement.cond + return session.statement.Conds() } diff --git a/session_convert.go b/session_convert.go index e7eabecc..735aefa6 100644 --- a/session_convert.go +++ b/session_convert.go @@ -15,6 +15,7 @@ import ( "time" "xorm.io/xorm/convert" + "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -108,7 +109,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val case reflect.Complex64, reflect.Complex128: x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { session.engine.logger.Error(err) return err @@ -122,7 +123,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val if col.SQLType.IsText() { x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { session.engine.logger.Error(err) return err @@ -135,7 +136,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val } else { x := reflect.New(fieldType) if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, x.Interface()) + err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { session.engine.logger.Error(err) return err @@ -264,7 +265,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val case schemas.Complex64Type.Kind(): var x complex64 if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, &x) + err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { session.engine.logger.Error(err) return err @@ -275,7 +276,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val case schemas.Complex128Type.Kind(): var x complex128 if len(data) > 0 { - err := DefaultJSONHandler.Unmarshal(data, &x) + err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { session.engine.logger.Error(err) return err @@ -615,14 +616,14 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. } if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { session.engine.logger.Error(err) return 0, err } return string(bytes), nil } else if col.SQLType.IsBlob() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { session.engine.logger.Error(err) return 0, err @@ -631,7 +632,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. } return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) case reflect.Complex64, reflect.Complex128: - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { session.engine.logger.Error(err) return 0, err @@ -643,7 +644,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. } if col.SQLType.IsText() { - bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { session.engine.logger.Error(err) return 0, err @@ -656,7 +657,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. (fieldValue.Type().Elem().Kind() == reflect.Uint8) { bytes = fieldValue.Bytes() } else { - bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { session.engine.logger.Error(err) return 0, err diff --git a/session_delete.go b/session_delete.go index 6bcb3852..f21151e1 100644 --- a/session_delete.go +++ b/session_delete.go @@ -23,7 +23,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -80,11 +80,11 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } @@ -98,7 +98,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - condSQL, condArgs, err := session.statement.genConds(bean) + condSQL, condArgs, err := session.statement.GenConds(bean) if err != nil { return 0, err } @@ -152,7 +152,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { var realSQL string argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled + if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled realSQL = deleteSQL copy(argsForCache, condArgs) argsForCache = append(condArgs, argsForCache...) diff --git a/session_exist.go b/session_exist.go index d5b0c1d8..e52c618e 100644 --- a/session_exist.go +++ b/session_exist.go @@ -4,89 +4,19 @@ package xorm -import ( - "errors" - "fmt" - "reflect" - - "xorm.io/builder" - "xorm.io/xorm/schemas" -) - // Exist returns true if the record exist otherwise return false func (session *Session) Exist(bean ...interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } - if session.statement.lastError != nil { - return false, session.statement.lastError + if session.statement.LastError != nil { + return false, session.statement.LastError } - var sqlStr string - var args []interface{} - var joinStr string - var err error - if session.statement.RawSQL == "" { - if len(bean) == 0 { - tableName := session.statement.TableName() - if len(tableName) <= 0 { - return false, ErrTableNotFound - } - - tableName = session.statement.quote(tableName) - if len(session.statement.JoinStr) > 0 { - joinStr = session.statement.JoinStr - } - - if session.statement.cond.IsValid() { - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return false, err - } - - if session.engine.dialect.DBType() == schemas.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL) - } else if session.engine.dialect.DBType() == schemas.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL) - } - args = condArgs - } else { - if session.engine.dialect.DBType() == schemas.MSSQL { - sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr) - } else if session.engine.dialect.DBType() == schemas.ORACLE { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr) - } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr) - } - args = []interface{}{} - } - } else { - beanValue := reflect.ValueOf(bean[0]) - if beanValue.Kind() != reflect.Ptr { - return false, errors.New("needs a pointer") - } - - if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefBean(bean[0]); err != nil { - return false, err - } - } - - if len(session.statement.TableName()) <= 0 { - return false, ErrTableNotFound - } - session.statement.Limit(1) - sqlStr, args, err = session.statement.genGetSQL(bean[0]) - if err != nil { - return false, err - } - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenExistSQL(bean...) + if err != nil { + return false, err } rows, err := session.queryRows(sqlStr, args...) diff --git a/session_find.go b/session_find.go index 6903c1b9..a3ba2c82 100644 --- a/session_find.go +++ b/session_find.go @@ -8,10 +8,10 @@ import ( "errors" "fmt" "reflect" - "strings" "xorm.io/builder" "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -53,8 +53,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } session.autoResetStatement = true - if session.statement.selectStr != "" { - session.statement.selectStr = "" + if session.statement.SelectStr != "" { + session.statement.SelectStr = "" } if session.statement.OrderStr != "" { session.statement.OrderStr = "" @@ -66,8 +66,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { defer session.resetStatement() - if session.statement.lastError != nil { - return session.statement.lastError + if session.statement.LastError != nil { + return session.statement.LastError } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) @@ -82,7 +82,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { pv := reflect.New(sliceElementType.Elem()) - if err := session.statement.setRefValue(pv); err != nil { + if err := session.statement.SetRefValue(pv); err != nil { return err } } else { @@ -90,7 +90,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceElementType.Kind() == reflect.Struct { pv := reflect.New(sliceElementType) - if err := session.statement.setRefValue(pv); err != nil { + if err := session.statement.SetRefValue(pv); err != nil { return err } } else { @@ -103,16 +103,16 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) var addedTableName = (len(session.statement.JoinStr) > 0) var autoCond builder.Cond if tp == tpStruct { - if !session.statement.noAutoCondition && len(condiBean) > 0 { + if !session.statement.NoAutoCondition && len(condiBean) > 0 { var err error - autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) + autoCond, err = session.statement.BuildConds(table, condiBean[0], true, true, false, true, addedTableName) if err != nil { return err } } else { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://gitea.com/xorm/xorm/issues/179 - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled var colName = session.engine.Quote(col.Name) if addedTableName { var nm = session.statement.TableName() @@ -122,70 +122,20 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) colName = session.engine.Quote(nm) + "." + colName } - autoCond = session.engine.CondDeleted(col) + autoCond = session.statement.CondDeleted(col) } } } - var sqlStr string - var args []interface{} - var err error - if session.statement.RawSQL == "" { - if len(session.statement.TableName()) <= 0 { - return ErrTableNotFound - } - - var columnStr = session.statement.columnStr() - if len(session.statement.selectStr) > 0 { - columnStr = session.statement.selectStr - } else { - if session.statement.JoinStr == "" { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr) - } else { - columnStr = session.statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - - session.statement.cond = session.statement.cond.And(autoCond) - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return err - } - - args = append(session.statement.joinArgs, condArgs...) - sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return err - } - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenFindSQL(autoCond) + if err != nil { + return err } if session.canCache() { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.IsDistinct && - !session.statement.unscoped { + !session.statement.GetUnscoped() { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err @@ -274,7 +224,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) - dataStruct := rValue(newValue.Interface()) + dataStruct := utils.ReflectValue(newValue.Interface()) tb, err := session.engine.tagParser.MapType(dataStruct) if err != nil { return err @@ -323,8 +273,8 @@ func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) erro func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if !session.canCache() || - indexNoCase(sqlStr, "having") != -1 || - indexNoCase(sqlStr, "group by") != -1 { + utils.IndexNoCase(sqlStr, "having") != -1 || + utils.IndexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } @@ -338,7 +288,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } diff --git a/session_find_test.go b/session_find_test.go index 8df3bc84..ad9d1668 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" ) @@ -299,13 +300,13 @@ func TestHaving(t *testing.T) { func TestOrderSameMapper(t *testing.T) { assert.NoError(t, prepareEngine()) - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) mapper := testEngine.GetTableMapper() testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(mapper) }() @@ -324,12 +325,12 @@ func TestOrderSameMapper(t *testing.T) { func TestHavingSameMapper(t *testing.T) { assert.NoError(t, prepareEngine()) - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) mapper := testEngine.GetTableMapper() testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(mapper) }() assertSync(t, new(Userinfo)) diff --git a/session_get.go b/session_get.go index d1e96958..f0fc016b 100644 --- a/session_get.go +++ b/session_get.go @@ -12,6 +12,7 @@ import ( "strconv" "xorm.io/xorm/caches" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -27,8 +28,8 @@ func (session *Session) Get(bean interface{}) (bool, error) { func (session *Session) get(bean interface{}) (bool, error) { defer session.resetStatement() - if session.statement.lastError != nil { - return false, session.statement.lastError + if session.statement.LastError != nil { + return false, session.statement.LastError } beanValue := reflect.ValueOf(bean) @@ -39,7 +40,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return false, err } } @@ -53,7 +54,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, ErrTableNotFound } session.statement.Limit(1) - sqlStr, args, err = session.statement.genGetSQL(bean) + sqlStr, args, err = session.statement.GenGetSQL(bean) if err != nil { return false, err } @@ -66,7 +67,7 @@ func (session *Session) get(bean interface{}) (bool, error) { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && - !session.statement.unscoped { + !session.statement.GetUnscoped() { has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err @@ -74,7 +75,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } } - context := session.statement.context + context := session.statement.Context if context != nil { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { @@ -244,7 +245,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, // close it before covert data rows.Close() - dataStruct := rValue(bean) + dataStruct := utils.ReflectValue(bean) _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) if err != nil { return true, err @@ -274,7 +275,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf for _, filter := range session.engine.dialect.Filters() { sqlStr = filter.Do(sqlStr) } - newsql := session.statement.convertIDSQL(sqlStr) + newsql := session.statement.ConvertIDSQL(sqlStr) if newsql == "" { return false, ErrCacheFailed } diff --git a/session_get_test.go b/session_get_test.go index f1e8c7f6..b7eac2b4 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/contexts" "xorm.io/xorm/schemas" ) @@ -417,7 +418,7 @@ func TestContextGet(t *testing.T) { sess := testEngine.NewSession() defer sess.Close() - context := NewMemoryContextCache() + context := contexts.NewMemoryContextCache() var c2 ContextGetStruct has, err := sess.ID(1).NoCache().ContextCache(context).Get(&c2) @@ -452,7 +453,7 @@ func TestContextGet2(t *testing.T) { _, err := testEngine.Insert(&ContextGetStruct2{Name: "1"}) assert.NoError(t, err) - context := NewMemoryContextCache() + context := contexts.NewMemoryContextCache() var c2 ContextGetStruct2 has, err := testEngine.ID(1).NoCache().ContextCache(context).Get(&c2) diff --git a/session_insert.go b/session_insert.go index 4e822247..2206ad05 100644 --- a/session_insert.go +++ b/session_insert.go @@ -113,7 +113,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error return 0, errors.New("could not insert a empty slice") } - if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { + if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil { return 0, err } @@ -163,10 +163,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { @@ -178,7 +178,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) var colName = col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { @@ -214,10 +214,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if col.IsDeleted { continue } - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { @@ -229,7 +229,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) var colName = col.Name session.afterClosures = append(session.afterClosures, func(bean interface{}) { @@ -329,7 +329,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { @@ -353,7 +353,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns colPlaces := strings.Repeat("?, ", len(colNames)) if exprs.Len() <= 0 && len(colPlaces) > 0 { colPlaces = colPlaces[0 : len(colPlaces)-2] @@ -385,25 +385,25 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { + if err := session.engine.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { return 0, err } - if session.statement.cond.IsValid() { + if session.statement.Conds().IsValid() { if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { return 0, err } - if err := session.statement.writeArgs(buf, args); err != nil { + if err := session.statement.WriteArgs(buf, args); err != nil { return 0, err } - if len(exprs.args) > 0 { + if len(exprs.Args) > 0 { if _, err := buf.WriteString(","); err != nil { return 0, err } } - if err := exprs.writeArgs(buf); err != nil { + if err := exprs.WriteArgs(buf); err != nil { return 0, err } @@ -411,7 +411,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if err := session.statement.cond.WriteTo(buf); err != nil { + if err := session.statement.Conds().WriteTo(buf); err != nil { return 0, err } } else { @@ -423,7 +423,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - if err := exprs.writeArgs(buf); err != nil { + if err := exprs.WriteArgs(buf); err != nil { return 0, err } @@ -482,7 +482,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { session.engine.logger.Error(err) @@ -523,7 +523,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { session.engine.logger.Error(err) @@ -564,7 +564,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.cacheInsert(tableName) - if table.Version != "" && session.statement.checkVersion { + if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { session.engine.logger.Error(err) @@ -637,19 +637,19 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac continue } - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } @@ -681,7 +681,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) @@ -698,7 +698,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { arg, err := session.value2Interface(col, fieldValue) @@ -724,9 +724,9 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } @@ -751,9 +751,9 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) { } var columns = make([]string, 0, len(m)) - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns for k := range m { - if !exprs.isColExist(k) { + if !exprs.IsColExist(k) { columns = append(columns, k) } } @@ -774,15 +774,15 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, ErrTableNotFound } - exprs := session.statement.exprColumns + exprs := session.statement.ExprColumns w := builder.NewWriter() // if insert where - if session.statement.cond.IsValid() { + if session.statement.Conds().IsValid() { if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { return 0, err } - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { return 0, err } @@ -790,15 +790,15 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, err } - if err := session.statement.writeArgs(w, args); err != nil { + if err := session.statement.WriteArgs(w, args); err != nil { return 0, err } - if len(exprs.args) > 0 { + if len(exprs.Args) > 0 { if _, err := w.WriteString(","); err != nil { return 0, err } - if err := exprs.writeArgs(w); err != nil { + if err := exprs.WriteArgs(w); err != nil { return 0, err } } @@ -807,7 +807,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, err } - if err := session.statement.cond.WriteTo(w); err != nil { + if err := session.statement.Conds().WriteTo(w); err != nil { return 0, err } } else { @@ -818,7 +818,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, return 0, err } - if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { + if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil { return 0, err } if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil { @@ -826,11 +826,11 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64, } w.Append(args...) - if len(exprs.args) > 0 { + if len(exprs.Args) > 0 { if _, err := w.WriteString(","); err != nil { return 0, err } - if err := exprs.writeArgs(w); err != nil { + if err := exprs.WriteArgs(w); err != nil { return 0, err } } diff --git a/session_iterate.go b/session_iterate.go index 4a3cc083..8cab8f48 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -6,6 +6,8 @@ package xorm import ( "reflect" + + "xorm.io/xorm/internal/utils" ) // IterFunc only use by Iterate @@ -25,11 +27,11 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { defer session.Close() } - if session.statement.lastError != nil { - return session.statement.lastError + if session.statement.LastError != nil { + return session.statement.LastError } - if session.statement.bufferSize > 0 { + if session.statement.BufferSize > 0 { return session.bufferIterate(bean, fun) } @@ -57,18 +59,18 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { // BufferSize sets the buffersize for iterate func (session *Session) BufferSize(size int) *Session { - session.statement.bufferSize = size + session.statement.BufferSize = size return session } func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { - var bufferSize = session.statement.bufferSize + var bufferSize = session.statement.BufferSize var pLimitN = session.statement.LimitN if pLimitN != nil && bufferSize > *pLimitN { bufferSize = *pLimitN } var start = session.statement.Start - v := rValue(bean) + v := utils.ReflectValue(bean) sliceType := reflect.SliceOf(v.Type()) var idx = 0 session.autoResetStatement = false diff --git a/session_query.go b/session_query.go index 1783e154..12136466 100644 --- a/session_query.go +++ b/session_query.go @@ -8,83 +8,19 @@ import ( "fmt" "reflect" "strconv" - "strings" "time" - "xorm.io/builder" "xorm.io/xorm/core" "xorm.io/xorm/schemas" ) -func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { - if len(sqlOrArgs) > 0 { - return convertSQLOrArgs(sqlOrArgs...) - } - - if session.statement.RawSQL != "" { - return session.statement.RawSQL, session.statement.RawParams, nil - } - - if len(session.statement.TableName()) <= 0 { - return "", nil, ErrTableNotFound - } - - var columnStr = session.statement.columnStr() - if len(session.statement.selectStr) > 0 { - columnStr = session.statement.selectStr - } else { - if session.statement.JoinStr == "" { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr) - } else { - columnStr = session.statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.statement.GroupByStr != "" { - columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr) - } else { - columnStr = "*" - } - } - } - if columnStr == "" { - columnStr = "*" - } - } - - if err := session.statement.processIDParam(); err != nil { - return "", nil, err - } - - condSQL, condArgs, err := builder.ToSQL(session.statement.cond) - if err != nil { - return "", nil, err - } - - args := append(session.statement.joinArgs, condArgs...) - sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } - // for mssql and use limit - qs := strings.Count(sqlStr, "?") - if len(args)*2 == qs { - args = append(args, args...) - } - - return sqlStr, args, nil -} - // Query runs a raw sql and return records as []map[string][]byte func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) { if session.isAutoClose { defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -233,7 +169,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -253,7 +189,7 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string, defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } @@ -306,7 +242,7 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i defer session.Close() } - sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) + sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...) if err != nil { return nil, err } diff --git a/session_raw.go b/session_raw.go index 51487779..efd74710 100644 --- a/session_raw.go +++ b/session_raw.go @@ -9,8 +9,8 @@ import ( "reflect" "time" - "xorm.io/builder" "xorm.io/xorm/core" + "xorm.io/xorm/internal/statements" ) func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { @@ -196,20 +196,6 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er return session.DB().ExecContext(session.ctx, sqlStr, args...) } -func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { - switch sqlOrArgs[0].(type) { - case string: - return sqlOrArgs[0].(string), sqlOrArgs[1:], nil - case *builder.Builder: - return sqlOrArgs[0].(*builder.Builder).ToSQL() - case builder.Builder: - bd := sqlOrArgs[0].(builder.Builder) - return bd.ToSQL() - } - - return "", nil, ErrUnSupportedType -} - // Exec raw sql func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { if session.isAutoClose { @@ -220,7 +206,7 @@ func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { return nil, ErrUnSupportedType } - sqlStr, args, err := convertSQLOrArgs(sqlOrArgs...) + sqlStr, args, err := statements.ConvertSQLOrArgs(sqlOrArgs...) if err != nil { return nil, err } diff --git a/session_schema.go b/session_schema.go index 05b24c91..0279ced7 100644 --- a/session_schema.go +++ b/session_schema.go @@ -33,11 +33,11 @@ func (session *Session) CreateTable(bean interface{}) error { } func (session *Session) createTable(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqlStr := session.statement.genCreateTableSQL() + sqlStr := session.statement.GenCreateTableSQL() _, err := session.exec(sqlStr) return err } @@ -52,11 +52,11 @@ func (session *Session) CreateIndexes(bean interface{}) error { } func (session *Session) createIndexes(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genIndexSQL() + sqls := session.statement.GenIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -75,11 +75,11 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createUniques(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genUniqueSQL() + sqls := session.statement.GenUniqueSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -99,11 +99,11 @@ func (session *Session) DropIndexes(bean interface{}) error { } func (session *Session) dropIndexes(bean interface{}) error { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return err } - sqls := session.statement.genDelIndexSQL() + sqls := session.statement.GenDelIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { @@ -201,7 +201,7 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) - sql := session.statement.dialect.AddColumnSQL(session.statement.TableName(), col) + sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col) _, err := session.exec(sql) return err } @@ -241,7 +241,7 @@ func (session *Session) Sync2(beans ...interface{}) error { }() for _, bean := range beans { - v := rValue(bean) + v := utils.ReflectValue(bean) table, err := engine.tagParser.MapType(v) if err != nil { return err @@ -299,7 +299,7 @@ func (session *Session) Sync2(beans ...interface{}) error { // column is not exist on table if oriCol == nil { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) if err = session.addColumn(col.Name); err != nil { return err } @@ -409,11 +409,11 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == schemas.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) err = session.addUnique(tbNameWithSchema, name) } else if index.Type == schemas.IndexType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.SetTableName(tbNameWithSchema) err = session.addIndex(tbNameWithSchema, name) } if err != nil { diff --git a/session_stats.go b/session_stats.go index c2cac830..17d0a675 100644 --- a/session_stats.go +++ b/session_stats.go @@ -17,17 +17,9 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { defer session.Close() } - var sqlStr string - var args []interface{} - var err error - if session.statement.RawSQL == "" { - sqlStr, args, err = session.statement.genCountSQL(bean...) - if err != nil { - return 0, err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenCountSQL(bean...) + if err != nil { + return 0, err } var total int64 @@ -50,21 +42,12 @@ func (session *Session) sum(res interface{}, bean interface{}, columnNames ...st return errors.New("need a pointer to a variable") } - var isSlice = v.Elem().Kind() == reflect.Slice - var sqlStr string - var args []interface{} - var err error - if len(session.statement.RawSQL) == 0 { - sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...) - if err != nil { - return err - } - } else { - sqlStr = session.statement.RawSQL - args = session.statement.RawParams + sqlStr, args, err := session.statement.GenSumSQL(bean, columnNames...) + if err != nil { + return err } - if isSlice { + if v.Elem().Kind() == reflect.Slice { err = session.queryRow(sqlStr, args...).ScanSlice(res) } else { err = session.queryRow(sqlStr, args...).Scan(res) diff --git a/session_tx_test.go b/session_tx_test.go index 1e3dcabf..303fd8d6 100644 --- a/session_tx_test.go +++ b/session_tx_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" ) @@ -85,10 +86,10 @@ func TestCombineTransactionSameMapper(t *testing.T) { assert.NoError(t, prepareEngine()) oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) testEngine.SetMapper(oldMapper) }() diff --git a/session_update.go b/session_update.go index 4330afae..bb53c3a1 100644 --- a/session_update.go +++ b/session_update.go @@ -23,7 +23,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri return ErrCacheFailed } - oldhead, newsql := session.statement.convertUpdateSQL(sqlStr) + oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -88,12 +88,12 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri return err } if bean := cacher.GetBean(tableName, sid); bean != nil { - sqls := splitNNoCase(sqlStr, "where", 2) + sqls := utils.SplitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } - sqls = splitNNoCase(sqls[0], "set", 2) + sqls = utils.SplitNNoCase(sqls[0], "set", 2) if len(sqls) != 2 { return ErrCacheFailed } @@ -112,7 +112,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri session.engine.logger.Error(err) } else { session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.statement.checkVersion { + if col.IsVersion && session.statement.CheckVersion { session.incrVersionFieldValue(fieldValue) } else { fieldValue.Set(reflect.ValueOf(args[idx])) @@ -144,11 +144,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } - if session.statement.lastError != nil { - return 0, session.statement.lastError + if session.statement.LastError != nil { + return 0, session.statement.LastError } - v := rValue(bean) + v := utils.ReflectValue(bean) t := v.Type() var colNames []string @@ -168,7 +168,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefBean(bean); err != nil { + if err := session.statement.SetRefBean(bean); err != nil { return 0, err } @@ -176,14 +176,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrTableNotFound } - if session.statement.columnStr() == "" { - colNames, args = session.statement.buildUpdates(bean, false, false, + if session.statement.ColumnStr() == "" { + colNames, args, err = session.statement.BuildUpdates(bean, false, false, false, false, true) } else { colNames, args, err = session.genUpdateColumns(bean) - if err != nil { - return 0, err - } + } + if err != nil { + return 0, err } } else if isMap { colNames = make([]string, 0) @@ -201,8 +201,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table := session.statement.RefTable if session.statement.UseAutoTime && table != nil && table.Updated != "" { - if !session.statement.columnMap.contain(table.Updated) && - !session.statement.omitColumnMap.contain(table.Updated) { + if !session.statement.ColumnMap.Contain(table.Updated) && + !session.statement.OmitColumnMap.Contain(table.Updated) { colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") col := table.UpdatedColumn() val, t := session.engine.nowTime(col) @@ -219,21 +219,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } // for update action to like "column = column + ?" - incColumns := session.statement.incrColumns - for i, colName := range incColumns.colNames { + incColumns := session.statement.IncrColumns + for i, colName := range incColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") - args = append(args, incColumns.args[i]) + args = append(args, incColumns.Args[i]) } // for update action to like "column = column - ?" - decColumns := session.statement.decrColumns - for i, colName := range decColumns.colNames { + decColumns := session.statement.DecrColumns + for i, colName := range decColumns.ColNames { colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") - args = append(args, decColumns.args[i]) + args = append(args, decColumns.Args[i]) } // for update action to like "column = expression" - exprColumns := session.statement.exprColumns - for i, colName := range exprColumns.colNames { - switch tp := exprColumns.args[i].(type) { + exprColumns := session.statement.ExprColumns + for i, colName := range exprColumns.ColNames { + switch tp := exprColumns.Args[i].(type) { case string: if len(tp) == 0 { tp = "''" @@ -248,16 +248,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, subArgs...) default: colNames = append(colNames, session.engine.Quote(colName)+"=?") - args = append(args, exprColumns.args[i]) + args = append(args, exprColumns.Args[i]) } } - if err = session.statement.processIDParam(); err != nil { + if err = session.statement.ProcessIDParam(); err != nil { return 0, err } var autoCond builder.Cond - if !session.statement.noAutoCondition { + if !session.statement.NoAutoCondition { condBeanIsStruct := false if len(condiBean) > 0 { if c, ok := condiBean[0].(map[string]interface{}); ok { @@ -270,7 +270,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if k == reflect.Struct { var err error - autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) + autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) if err != nil { return 0, err } @@ -282,8 +282,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if !condBeanIsStruct && table != nil { - if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled - autoCond1 := session.engine.CondDeleted(col) + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond1 := session.statement.CondDeleted(col) if autoCond == nil { autoCond = autoCond1 @@ -294,15 +294,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - st := &session.statement + st := session.statement var ( sqlStr string condArgs []interface{} condSQL string - cond = session.statement.cond.And(autoCond) + cond = session.statement.Conds().And(autoCond) - doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.checkVersion) + doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) verValue *reflect.Value ) if doIncVer { @@ -335,9 +335,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var top string if st.LimitN != nil { limitValue := *st.LimitN - if st.dialect.DBType() == schemas.MYSQL { + if session.engine.dialect.DBType() == schemas.MYSQL { condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) - } else if st.dialect.DBType() == schemas.SQLITE { + } else if session.engine.dialect.DBType() == schemas.SQLITE { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) @@ -348,7 +348,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.dialect.DBType() == schemas.POSTGRES { + } else if session.engine.dialect.DBType() == schemas.POSTGRES { tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.engine.Quote(tableName), tempCondSQL), condArgs...)) @@ -360,8 +360,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } - } else if st.dialect.DBType() == schemas.MSSQL { - if st.OrderStr != "" && st.dialect.DBType() == schemas.MSSQL && + } else if session.engine.dialect.DBType() == schemas.MSSQL { + if st.OrderStr != "" && session.engine.dialect.DBType() == schemas.MSSQL && table != nil && len(table.PrimaryKeys) == 1 { cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], @@ -459,7 +459,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac for _, col := range table.Columns() { if !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if session.statement.omitColumnMap.contain(col.Name) { + if session.statement.OmitColumnMap.Contain(col.Name) { continue } } @@ -494,25 +494,25 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac } } - if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated { + if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated { continue } // if only update specify columns - if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { + if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) { continue } - if session.statement.incrColumns.isColExist(col.Name) { + if session.statement.IncrColumns.IsColExist(col.Name) { continue - } else if session.statement.decrColumns.isColExist(col.Name) { + } else if session.statement.DecrColumns.IsColExist(col.Name) { continue - } else if session.statement.exprColumns.isColExist(col.Name) { + } else if session.statement.ExprColumns.IsColExist(col.Name) { continue } // !evalphobia! set fieldValue as nil when column is nullable and zero-value - if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { + if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok { if col.Nullable && utils.IsValueZero(fieldValue) { var nilValue *int fieldValue = reflect.ValueOf(nilValue) @@ -529,7 +529,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac col := table.GetColumn(colName) setColumnTime(bean, col, t) }) - } else if col.IsVersion && session.statement.checkVersion { + } else if col.IsVersion && session.statement.CheckVersion { args = append(args, 1) } else { arg, err := session.value2Interface(col, fieldValue) diff --git a/session_update_test.go b/session_update_test.go index 2d310aa1..0ef59155 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" ) @@ -685,20 +686,20 @@ func TestUpdateSameMapper(t *testing.T) { assert.NoError(t, prepareEngine()) oldMapper := testEngine.GetTableMapper() - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.UnMapType(rValue(new(Condi)).Type()) - testEngine.UnMapType(rValue(new(Article)).Type()) - testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Article)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type()) testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(Userinfo)).Type()) - testEngine.UnMapType(rValue(new(Condi)).Type()) - testEngine.UnMapType(rValue(new(Article)).Type()) - testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) - testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(Article)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type()) testEngine.SetMapper(oldMapper) }() diff --git a/statement_test.go b/statement_test.go index 6e5564b0..57d6e477 100644 --- a/statement_test.go +++ b/statement_test.go @@ -5,185 +5,11 @@ package xorm import ( - "reflect" - "strings" "testing" "github.com/stretchr/testify/assert" - "xorm.io/xorm/schemas" ) -var colStrTests = []struct { - omitColumn string - onlyToDBColumnNdx int - expected string -}{ - {"", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, - {"Code2", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, - {"", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, - {"Code3", 1, "`ID`, `Caption`, `Code1`, `Code2`, `ParentID`, `Latitude`, `Longitude`"}, - {"Longitude", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, - {"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, -} - -func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" || dbType == "mssql" { - return - } - - var statement *Statement - - for ndx, testCase := range colStrTests { - statement = createTestStatement() - - if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) - } - - columns := statement.RefTable.Columns() - if testCase.onlyToDBColumnNdx >= 0 { - columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB - } - - actual := statement.genColumnStr() - - if actual != testCase.expected { - t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual) - } - if testCase.onlyToDBColumnNdx >= 0 { - columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES - } - } -} - -func BenchmarkColumnsStringGeneration(b *testing.B) { - b.StopTimer() - - statement := createTestStatement() - - testCase := colStrTests[0] - - if testCase.omitColumn != "" { - statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped - } - - if testCase.onlyToDBColumnNdx >= 0 { - columns := statement.RefTable.Columns() - columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - actual := statement.genColumnStr() - - if actual != testCase.expected { - b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) - } - } -} - -func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) { - - b.StopTimer() - - mapCols := make(map[string]bool) - cols := []*schemas.Column{ - {Name: `ID`}, - {Name: `IsDeleted`}, - {Name: `Caption`}, - {Name: `Code1`}, - {Name: `Code2`}, - {Name: `Code3`}, - {Name: `ParentID`}, - {Name: `Latitude`}, - {Name: `Longitude`}, - } - - for _, col := range cols { - mapCols[strings.ToLower(col.Name)] = true - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - - for _, col := range cols { - - if _, ok := getFlagForColumn(mapCols, col); !ok { - b.Fatal("Unexpected result") - } - } - } -} - -func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) { - - b.StopTimer() - - mapCols := make(map[string]bool) - cols := []*schemas.Column{ - {Name: `ID`}, - {Name: `IsDeleted`}, - {Name: `Caption`}, - {Name: `Code1`}, - {Name: `Code2`}, - {Name: `Code3`}, - {Name: `ParentID`}, - {Name: `Latitude`}, - {Name: `Longitude`}, - } - - b.StartTimer() - - for i := 0; i < b.N; i++ { - - for _, col := range cols { - - if _, ok := getFlagForColumn(mapCols, col); ok { - b.Fatal("Unexpected result") - } - } - } -} - -type TestType struct { - ID int64 `xorm:"ID PK"` - IsDeleted bool `xorm:"IsDeleted"` - Caption string `xorm:"Caption"` - Code1 string `xorm:"Code1"` - Code2 string `xorm:"Code2"` - Code3 string `xorm:"Code3"` - ParentID int64 `xorm:"ParentID"` - Latitude float64 `xorm:"Latitude"` - Longitude float64 `xorm:"Longitude"` -} - -func (TestType) TableName() string { - return "TestTable" -} - -func createTestStatement() *Statement { - if engine, ok := testEngine.(*Engine); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = engine - statement.dialect = engine.dialect - statement.setRefValue(reflect.ValueOf(TestType{})) - - return statement - } else if eg, ok := testEngine.(*EngineGroup); ok { - statement := &Statement{} - statement.Reset() - statement.Engine = eg.Engine - statement.dialect = eg.Engine.dialect - statement.setRefValue(reflect.ValueOf(TestType{})) - - return statement - } - return nil -} - func TestDistinctAndCols(t *testing.T) { type DistinctAndCols struct { Id int64 diff --git a/tags_test.go b/tags_test.go index b8a43670..9d41a5fa 100644 --- a/tags_test.go +++ b/tags_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "xorm.io/xorm/internal/utils" "xorm.io/xorm/names" "xorm.io/xorm/schemas" ) @@ -608,10 +609,10 @@ func TestGonicMapperID(t *testing.T) { assert.NoError(t, prepareEngine()) oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type()) testEngine.SetMapper(names.LintGonicMapper) defer func() { - testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type()) testEngine.SetMapper(oldMapper) }() @@ -645,10 +646,10 @@ func TestSameMapperID(t *testing.T) { assert.NoError(t, prepareEngine()) oldMapper := testEngine.GetColumnMapper() - testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type()) testEngine.SetMapper(names.SameMapper{}) defer func() { - testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) + testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type()) testEngine.SetMapper(oldMapper) }() @@ -818,7 +819,9 @@ func TestAutoIncrTag(t *testing.T) { Id int64 } - tb := testEngine.TableInfo(new(TestAutoIncr1)) + tb, err := testEngine.TableInfo(new(TestAutoIncr1)) + assert.NoError(t, err) + cols := tb.Columns() assert.EqualValues(t, 1, len(cols)) assert.True(t, cols[0].IsAutoIncrement) @@ -829,7 +832,9 @@ func TestAutoIncrTag(t *testing.T) { Id int64 `xorm:"id"` } - tb = testEngine.TableInfo(new(TestAutoIncr2)) + tb, err = testEngine.TableInfo(new(TestAutoIncr2)) + assert.NoError(t, err) + cols = tb.Columns() assert.EqualValues(t, 1, len(cols)) assert.False(t, cols[0].IsAutoIncrement) @@ -840,7 +845,9 @@ func TestAutoIncrTag(t *testing.T) { Id int64 `xorm:"'ID'"` } - tb = testEngine.TableInfo(new(TestAutoIncr3)) + tb, err = testEngine.TableInfo(new(TestAutoIncr3)) + assert.NoError(t, err) + cols = tb.Columns() assert.EqualValues(t, 1, len(cols)) assert.False(t, cols[0].IsAutoIncrement) @@ -851,7 +858,9 @@ func TestAutoIncrTag(t *testing.T) { Id int64 `xorm:"pk"` } - tb = testEngine.TableInfo(new(TestAutoIncr4)) + tb, err = testEngine.TableInfo(new(TestAutoIncr4)) + assert.NoError(t, err) + cols = tb.Columns() assert.EqualValues(t, 1, len(cols)) assert.False(t, cols[0].IsAutoIncrement) @@ -1035,7 +1044,9 @@ func TestTagDefault5(t *testing.T) { } assertSync(t, new(DefaultStruct5)) - table := testEngine.TableInfo(new(DefaultStruct5)) + table, err := testEngine.TableInfo(new(DefaultStruct5)) + assert.NoError(t, err) + createdCol := table.GetColumn("created") assert.NotNil(t, createdCol) assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) diff --git a/types_test.go b/types_test.go index 53872372..d8fd8309 100644 --- a/types_test.go +++ b/types_test.go @@ -10,6 +10,7 @@ import ( "testing" "xorm.io/xorm/convert" + "xorm.io/xorm/internal/json" "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" @@ -118,21 +119,21 @@ type ConvConfig struct { } func (s *ConvConfig) FromDB(data []byte) error { - return DefaultJSONHandler.Unmarshal(data, s) + return json.DefaultJSONHandler.Unmarshal(data, s) } func (s *ConvConfig) ToDB() ([]byte, error) { - return DefaultJSONHandler.Marshal(s) + return json.DefaultJSONHandler.Marshal(s) } type SliceType []*ConvConfig func (s *SliceType) FromDB(data []byte) error { - return DefaultJSONHandler.Unmarshal(data, s) + return json.DefaultJSONHandler.Unmarshal(data, s) } func (s *SliceType) ToDB() ([]byte, error) { - return DefaultJSONHandler.Marshal(s) + return json.DefaultJSONHandler.Marshal(s) } type ConvStruct struct {