Move statement as a sub package (#1564)

Fix test

Fix bug

Move statement as a sub package

Reviewed-on: https://gitea.com/xorm/xorm/pulls/1564
This commit is contained in:
Lunny Xiao 2020-02-28 12:29:08 +00:00
parent f63b42ff9b
commit 2b62dc5a51
50 changed files with 1995 additions and 1792 deletions

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package contexts
// ContextCache is the interface that operates the cache data. // ContextCache is the interface that operates the cache data.
type ContextCache interface { type ContextCache interface {

View File

@ -41,6 +41,7 @@ type Dialect interface {
DBType() DBType DBType() DBType
SQLType(*schemas.Column) string SQLType(*schemas.Column) string
FormatBytes(b []byte) string FormatBytes(b []byte) string
DefaultSchema() string
DriverName() string DriverName() string
DataSourceName() string DataSourceName() string
@ -103,6 +104,10 @@ func (b *Base) SetLogger(logger log.Logger) {
b.logger = logger b.logger = logger
} }
func (b *Base) DefaultSchema() string {
return ""
}
func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error { func (b *Base) Init(db *core.DB, dialect Dialect, uri *URI, drivername, dataSourceName string) error {
b.db, b.dialect, b.uri = db, dialect, uri b.db, b.dialect, b.uri = db, dialect, uri
b.driverName, b.dataSourceName = drivername, dataSourceName b.driverName, b.dataSourceName = drivername, dataSourceName

View File

@ -788,6 +788,10 @@ func (db *postgres) Init(d *core.DB, uri *URI, drivername, dataSourceName string
return nil return nil
} }
func (db *postgres) DefaultSchema() string {
return PostgresPublicSchema
}
func (db *postgres) SQLType(c *schemas.Column) string { func (db *postgres) SQLType(c *schemas.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {

90
dialects/table_name.go Normal file
View File

@ -0,0 +1,90 @@
// Copyright 2015 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"fmt"
"reflect"
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names"
)
// TableNameWithSchema will add schema prefix on table name if possible
func TableNameWithSchema(dialect Dialect, tableName string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if dialect.URI().Schema != "" &&
dialect.URI().Schema != dialect.DefaultSchema() &&
strings.Index(tableName, ".") == -1 {
return fmt.Sprintf("%s.%s", dialect.URI().Schema, tableName)
}
return tableName
}
// TableNameNoSchema returns table name with given tableName
func TableNameNoSchema(dialect Dialect, mapper names.Mapper, tableName interface{}) string {
quote := dialect.Quoter().Quote
switch tableName.(type) {
case []string:
t := tableName.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", quote(t[0]), quote(t[1]))
} else if len(t) == 1 {
return quote(t[0])
}
case []interface{}:
t := tableName.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case names.TableName:
table = f.(names.TableName).TableName()
default:
v := utils.ReflectValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = names.GetTableName(mapper, v)
} else {
table = quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", quote(table), quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return quote(table)
}
case names.TableName:
return tableName.(names.TableName).TableName()
case string:
return tableName.(string)
case reflect.Value:
v := tableName.(reflect.Value)
return names.GetTableName(mapper, v)
default:
v := utils.ReflectValue(tableName)
t := v.Type()
if t.Kind() == reflect.Struct {
return names.GetTableName(mapper, v)
}
return quote(fmt.Sprintf("%v", tableName))
}
return ""
}
// FullTableName returns table name with quote and schema according parameter
func FullTableName(dialect Dialect, mapper names.Mapper, bean interface{}, includeSchema ...bool) string {
tbName := TableNameNoSchema(dialect, mapper, bean)
if len(includeSchema) > 0 && includeSchema[0] && !utils.IsSubQuery(tbName) {
tbName = TableNameWithSchema(dialect, tbName)
}
return tbName
}

View File

@ -2,11 +2,13 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package dialects
import ( import (
"testing" "testing"
"xorm.io/xorm/names"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -20,9 +22,9 @@ func (mcc *MCC) TableName() string {
return "mcc" return "mcc"
} }
func TestTableName1(t *testing.T) { func TestFullTableName(t *testing.T) {
assert.NoError(t, prepareEngine()) dialect := QueryDialect("mysql")
assert.EqualValues(t, "mcc", testEngine.TableName(new(MCC))) assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, &MCC{}))
assert.EqualValues(t, "mcc", testEngine.TableName("mcc")) assert.EqualValues(t, "mcc", FullTableName(dialect, names.SnakeMapper{}, "mcc"))
} }

49
dialects/time.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2015 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dialects
import (
"time"
"xorm.io/xorm/schemas"
)
// FormatTime format time as column type
func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
}
func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() {
if col.Nullable {
return nil
}
return ""
}
if col.TimeZone != nil {
return FormatTime(dialect, col.SQLType.Name, t.In(col.TimeZone))
}
return FormatTime(dialect, col.SQLType.Name, t.In(defaultTimeZone))
}

View File

@ -18,7 +18,6 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/builder"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
@ -65,25 +64,6 @@ func (engine *Engine) BufferSize(size int) *Session {
return session.BufferSize(size) return session.BufferSize(size)
} }
// CondDeleted returns the conditions whether a record is soft deleted.
func (engine *Engine) CondDeleted(col *schemas.Column) builder.Cond {
var cond = builder.NewCond()
if col.SQLType.IsNumeric() {
cond = builder.Eq{col.Name: 0}
} else {
// FIXME: mssql: The conversion of a nvarchar data type to a datetime data type resulted in an out-of-range value.
if engine.dialect.DBType() != schemas.MSSQL {
cond = builder.Eq{col.Name: utils.ZeroTime1}
}
}
if col.Nullable {
cond = cond.Or(builder.IsNull{col.Name})
}
return cond
}
// ShowSQL show SQL statement or not on logger if log level is great than INFO // ShowSQL show SQL statement or not on logger if log level is great than INFO
func (engine *Engine) ShowSQL(show ...bool) { func (engine *Engine) ShowSQL(show ...bool) {
engine.logger.ShowSQL(show...) engine.logger.ShowSQL(show...)
@ -237,7 +217,7 @@ func (engine *Engine) NoCascade() *Session {
// MapCacher Set a table use a special cacher // MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error { func (engine *Engine) MapCacher(bean interface{}, cacher caches.Cacher) error {
engine.SetCacher(engine.TableName(bean, true), cacher) engine.SetCacher(dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, true), cacher)
return nil return nil
} }
@ -759,13 +739,13 @@ func (t *Table) IsValid() bool {
} }
// TableInfo get table info according to bean's content // TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table { func (engine *Engine) TableInfo(bean interface{}) (*Table, error) {
v := rValue(bean) v := utils.ReflectValue(bean)
tb, err := engine.tagParser.MapType(v) tb, err := engine.tagParser.MapType(v)
if err != nil { if err != nil {
engine.logger.Error(err) return nil, err
} }
return &Table{tb, engine.TableName(bean)} return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil
} }
// IsTableEmpty if a table has any reocrd // IsTableEmpty if a table has any reocrd
@ -787,6 +767,11 @@ func (engine *Engine) IDOf(bean interface{}) schemas.PK {
return engine.IDOfV(reflect.ValueOf(bean)) return engine.IDOfV(reflect.ValueOf(bean))
} }
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
return dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean, includeSchema...)
}
// IDOfV get id from one value of struct // IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK { func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK {
pk, err := engine.idOfV(rv) pk, err := engine.idOfV(rv)
@ -873,7 +858,7 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
// ClearCacheBean if enabled cache, clear the cache bean // ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
tableName := engine.TableName(bean) tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName) cacher := engine.GetCacher(tableName)
if cacher != nil { if cacher != nil {
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
@ -885,7 +870,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
// ClearCache if enabled cache, clear some tables' cache // ClearCache if enabled cache, clear some tables' cache
func (engine *Engine) ClearCache(beans ...interface{}) error { func (engine *Engine) ClearCache(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
tableName := engine.TableName(bean) tableName := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
cacher := engine.GetCacher(tableName) cacher := engine.GetCacher(tableName)
if cacher != nil { if cacher != nil {
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
@ -908,8 +893,8 @@ func (engine *Engine) Sync(beans ...interface{}) error {
defer session.Close() defer session.Close()
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := utils.ReflectValue(bean)
tableNameNoSchema := engine.TableName(bean) tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)
table, err := engine.tagParser.MapType(v) table, err := engine.tagParser.MapType(v)
if err != nil { if err != nil {
return err return err
@ -946,7 +931,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
@ -957,7 +942,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
for name, index := range table.Indexes { for name, index := range table.Indexes {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
if index.Type == schemas.UniqueType { if index.Type == schemas.UniqueType {
@ -966,7 +951,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
@ -981,7 +966,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if !isExist { if !isExist {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
@ -1250,45 +1235,11 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) {
if !col.DisableTimeZone && col.TimeZone != nil { if !col.DisableTimeZone && col.TimeZone != nil {
tz = col.TimeZone tz = col.TimeZone
} }
return engine.formatTime(col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation)
} }
func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) { func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() { return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t)
if col.Nullable {
return nil
}
return ""
}
if col.TimeZone != nil {
return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone))
}
return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ))
}
// formatTime format time as column type
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case schemas.Time:
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case schemas.Date:
v = t.Format("2006-01-02")
case schemas.DateTime, schemas.TimeStamp, schemas.Varchar: // !DarthPestilane! format time when sqlTypeName is schemas.Varchar.
v = t.Format("2006-01-02 15:04:05")
case schemas.TimeStampz:
if engine.dialect.DBType() == schemas.MSSQL {
v = t.Format("2006-01-02T15:04:05.9999999Z07:00")
} else {
v = t.Format(time.RFC3339Nano)
}
case schemas.BigInt, schemas.Int:
v = t.Unix()
default:
v = t
}
return
} }
// GetColumnMapper returns the column name mapper // GetColumnMapper returns the column name mapper
@ -1332,3 +1283,7 @@ func (engine *Engine) Unscoped() *Session {
session.isAutoClose = true session.isAutoClose = true
return session.Unscoped() return session.Unscoped()
} }
func (engine *Engine) tbNameWithSchema(v string) string {
return dialects.TableNameWithSchema(engine.dialect, v)
}

View File

@ -1,234 +0,0 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql/driver"
"fmt"
"reflect"
"strings"
"time"
"xorm.io/builder"
"xorm.io/xorm/convert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (engine *Engine) buildConds(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) {
var conds []builder.Cond
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if engine.dialect.DBType() == schemas.MSSQL && (col.SQLType.Name == schemas.Text || col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue
}
if col.SQLType.IsJson() {
continue
}
var colName string
if addedTableName {
var nm = tableName
if len(aliasName) > 0 {
nm = aliasName
}
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
} else {
colName = engine.Quote(col.Name)
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
if !strings.Contains(err.Error(), "is not valid") {
engine.logger.Warn(err)
}
continue
}
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
conds = append(conds, engine.CondDeleted(col))
}
fieldValue := *fieldValuePtr
if fieldValue.Interface() == nil {
continue
}
fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}
if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
conds = append(conds, builder.Eq{colName: nil})
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}
var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = engine.formatColTime(col, t)
} else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok {
continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else {
if col.SQLType.IsJson() {
if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
table, err := engine.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !utils.IsZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys)
}
}
}
}
case reflect.Array:
continue
case reflect.Slice, reflect.Map:
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else {
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
engine.logger.Error(err)
continue
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
conds = append(conds, builder.Eq{colName: val})
}
return builder.And(conds...), nil
}

View File

@ -1,109 +0,0 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"fmt"
"reflect"
"strings"
"xorm.io/xorm/dialects"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
)
// tbNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == schemas.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != dialects.PostgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
func isSubQuery(tbName string) bool {
const selStr = "select"
if len(tbName) <= len(selStr)+1 {
return false
}
return strings.EqualFold(tbName[:len(selStr)], selStr) || strings.EqualFold(tbName[:len(selStr)+1], "("+selStr)
}
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] && !isSubQuery(tbName) {
tbName = engine.tbNameWithSchema(tbName)
}
return tbName
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *schemas.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
return engine.Quote(t[0])
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case names.TableName:
table = f.(names.TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = names.GetTableName(engine.GetTableMapper(), v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
return engine.Quote(table)
}
case names.TableName:
return tablename.(names.TableName).TableName()
case string:
return tablename.(string)
case reflect.Value:
v := tablename.(reflect.Value)
return names.GetTableName(engine.GetTableMapper(), v)
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return names.GetTableName(engine.GetTableMapper(), v)
}
return engine.Quote(fmt.Sprintf("%v", tablename))
}
return ""
}

View File

@ -26,8 +26,6 @@ var (
ErrNotImplemented = errors.New("Not implemented") ErrNotImplemented = errors.New("Not implemented")
// ErrConditionType condition type unsupported // ErrConditionType condition type unsupported
ErrConditionType = errors.New("Unsupported condition type") ErrConditionType = errors.New("Unsupported condition type")
// ErrUnSupportedSQLType parameter of SQL is not supported
ErrUnSupportedSQLType = errors.New("Unsupported sql type")
) )
// ErrFieldIsNotExist columns does not exist // ErrFieldIsNotExist columns does not exist

View File

@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"time" "time"
) )
@ -138,26 +137,6 @@ func int64ToInt(id int64, tp reflect.Type) interface{} {
return int64ToIntValue(id, tp).Interface() return int64ToIntValue(id, tp).Interface()
} }
func indexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}
func splitNoCase(s, sep string) []string {
idx := indexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.Split(s, s[idx:idx+len(sep)])
}
func splitNNoCase(s, sep string, n int) []string {
idx := indexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.SplitN(s, s[idx:idx+len(sep)], n)
}
func makeArray(elem string, count int) []string { func makeArray(elem string, count int) []string {
res := make([]string, count) res := make([]string, count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
@ -166,10 +145,6 @@ func makeArray(elem string, count int) []string {
return res return res
} }
func rValue(bean interface{}) reflect.Value {
return reflect.Indirect(reflect.ValueOf(bean))
}
func rType(bean interface{}) reflect.Type { func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
// return reflect.TypeOf(sliceValue.Interface()) // return reflect.TypeOf(sliceValue.Interface())
@ -183,10 +158,6 @@ func structName(v reflect.Type) string {
return v.Name() return v.Name()
} }
func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}
func formatTime(t time.Time) string { func formatTime(t time.Time) string {
return t.Format("2006-01-02 15:04:05") return t.Format("2006-01-02 15:04:05")
} }

View File

@ -7,7 +7,6 @@ package xorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"reflect" "reflect"
"time" "time"
@ -113,7 +112,7 @@ type EngineInterface interface {
Sync(...interface{}) error Sync(...interface{}) error
Sync2(...interface{}) error Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table TableInfo(bean interface{}) (*Table, error)
TableName(interface{}, ...bool) string TableName(interface{}, ...bool) string
UnMapType(reflect.Type) UnMapType(reflect.Type)
} }
@ -123,27 +122,3 @@ var (
_ EngineInterface = &Engine{} _ EngineInterface = &Engine{}
_ EngineInterface = &EngineGroup{} _ EngineInterface = &EngineGroup{}
) )
// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)
// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}
// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

31
internal/json/json.go Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "encoding/json"
// JSONInterface represents an interface to handle json data
type JSONInterface interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
var (
// DefaultJSONHandler default json handler
DefaultJSONHandler JSONInterface = StdJSON{}
)
// StdJSON implements JSONInterface via encoding/json
type StdJSON struct{}
// Marshal implements JSONInterface
func (StdJSON) Marshal(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
// Unmarshal implements JSONInterface
func (StdJSON) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

View File

@ -0,0 +1,79 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"fmt"
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
func (statement *Statement) ConvertIDSQL(sqlStr string) string {
if statement.RefTable != nil {
cols := statement.RefTable.PKColumns()
if len(cols) == 0 {
return ""
}
colstrs := statement.joinColumns(cols, false)
sqls := utils.SplitNNoCase(sqlStr, " from ", 2)
if len(sqls) != 2 {
return ""
}
var top string
pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.DBType() == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
return newsql
}
return ""
}
func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
return "", ""
}
colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
sqls := utils.SplitNNoCase(sqlStr, "where", 2)
if len(sqls) != 2 {
if len(sqls) == 1 {
return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
colstrs, statement.quote(statement.TableName()))
}
return "", ""
}
var whereStr = sqls[1]
// TODO: for postgres only, if any other database?
var paraStr string
if statement.dialect.DBType() == schemas.POSTGRES {
paraStr = "$"
} else if statement.dialect.DBType() == schemas.MSSQL {
paraStr = ":"
}
if paraStr != "" {
if strings.Contains(sqls[1], paraStr) {
dollers := strings.Split(sqls[1], paraStr)
whereStr = dollers[0]
for i, c := range dollers[1:] {
ccs := strings.SplitN(c, " ", 2)
whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
}
}
}
return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
colstrs, statement.quote(statement.TableName()),
whereStr)
}

View File

@ -2,13 +2,17 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package statements
import "strings" import (
"strings"
"xorm.io/xorm/schemas"
)
type columnMap []string type columnMap []string
func (m columnMap) contain(colName string) bool { func (m columnMap) Contain(colName string) bool {
if len(m) == 0 { if len(m) == 0 {
return false return false
} }
@ -27,9 +31,28 @@ func (m columnMap) contain(colName string) bool {
} }
func (m *columnMap) add(colName string) bool { func (m *columnMap) add(colName string) bool {
if m.contain(colName) { if m.Contain(colName) {
return false return false
} }
*m = append(*m, colName) *m = append(*m, colName)
return true return true
} }
func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has bool) {
if len(m) == 0 {
return false, false
}
n := len(col.Name)
for mk := range m {
if len(mk) != n {
continue
}
if strings.EqualFold(mk, col.Name) {
return m[mk], true
}
}
return false, false
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package statements
import ( import (
"fmt" "fmt"
@ -26,21 +26,21 @@ type exprParam struct {
} }
type exprParams struct { type exprParams struct {
colNames []string ColNames []string
args []interface{} Args []interface{}
} }
func (exprs *exprParams) Len() int { func (exprs *exprParams) Len() int {
return len(exprs.colNames) return len(exprs.ColNames)
} }
func (exprs *exprParams) addParam(colName string, arg interface{}) { func (exprs *exprParams) addParam(colName string, arg interface{}) {
exprs.colNames = append(exprs.colNames, colName) exprs.ColNames = append(exprs.ColNames, colName)
exprs.args = append(exprs.args, arg) exprs.Args = append(exprs.Args, arg)
} }
func (exprs *exprParams) isColExist(colName string) bool { func (exprs *exprParams) IsColExist(colName string) bool {
for _, name := range exprs.colNames { for _, name := range exprs.ColNames {
if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) { if strings.EqualFold(schemas.CommonQuoter.Trim(name), schemas.CommonQuoter.Trim(colName)) {
return true return true
} }
@ -49,16 +49,16 @@ func (exprs *exprParams) isColExist(colName string) bool {
} }
func (exprs *exprParams) getByName(colName string) (exprParam, bool) { func (exprs *exprParams) getByName(colName string) (exprParam, bool) {
for i, name := range exprs.colNames { for i, name := range exprs.ColNames {
if strings.EqualFold(name, colName) { if strings.EqualFold(name, colName) {
return exprParam{name, exprs.args[i]}, true return exprParam{name, exprs.Args[i]}, true
} }
} }
return exprParam{}, false return exprParam{}, false
} }
func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error { func (exprs *exprParams) WriteArgs(w *builder.BytesWriter) error {
for i, expr := range exprs.args { for i, expr := range exprs.Args {
switch arg := expr.(type) { switch arg := expr.(type) {
case *builder.Builder: case *builder.Builder:
if _, err := w.WriteString("("); err != nil { if _, err := w.WriteString("("); err != nil {
@ -83,7 +83,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
} }
w.Append(arg) w.Append(arg)
} }
if i != len(exprs.args)-1 { if i != len(exprs.Args)-1 {
if _, err := w.WriteString(","); err != nil { if _, err := w.WriteString(","); err != nil {
return err return err
} }
@ -93,7 +93,7 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
} }
func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error { func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
for i, colName := range exprs.colNames { for i, colName := range exprs.ColNames {
if _, err := w.WriteString(colName); err != nil { if _, err := w.WriteString(colName); err != nil {
return err return err
} }
@ -101,7 +101,7 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
return err return err
} }
switch arg := exprs.args[i].(type) { switch arg := exprs.Args[i].(type) {
case *builder.Builder: case *builder.Builder:
if _, err := w.WriteString("("); err != nil { if _, err := w.WriteString("("); err != nil {
return err return err
@ -113,10 +113,10 @@ func (exprs *exprParams) writeNameArgs(w *builder.BytesWriter) error {
return err return err
} }
default: default:
w.Append(exprs.args[i]) w.Append(exprs.Args[i])
} }
if i+1 != len(exprs.colNames) { if i+1 != len(exprs.ColNames) {
if _, err := w.WriteString(","); err != nil { if _, err := w.WriteString(","); err != nil {
return err return err
} }

View File

@ -0,0 +1,448 @@
// Copyright 2019 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"errors"
"fmt"
"reflect"
"strings"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
if len(sqlOrArgs) > 0 {
return ConvertSQLOrArgs(sqlOrArgs...)
}
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := statement.ProcessIDParam(); err != nil {
return "", nil, err
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
args := append(statement.joinArgs, condArgs...)
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
statement.SetRefBean(bean)
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
colName = statement.quote(colName)
}
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
sumSelect := strings.Join(sumStrs, ", ")
condSQL, condArgs, err := statement.GenConds(bean)
if err != nil {
return "", nil, err
}
sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean)
isStruct := v.Kind() == reflect.Struct
if isStruct {
statement.SetRefBean(bean)
}
var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
// TODO: always generate column names, not use * even if join
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
}
}
}
}
if len(columnStr) == 0 {
columnStr = "*"
}
if isStruct {
if err := statement.mergeConds(bean); err != nil {
return "", nil, err
}
} else {
if err := statement.ProcessIDParam(); err != nil {
return "", nil, err
}
}
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
var condSQL string
var condArgs []interface{}
var err error
if len(beans) > 0 {
statement.SetRefBean(beans[0])
condSQL, condArgs, err = statement.GenConds(beans[0])
} else {
condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil {
return "", nil, err
}
var selectSQL = statement.SelectStr
if len(selectSQL) <= 0 {
if statement.IsDistinct {
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
} else {
selectSQL = "count(*)"
}
}
sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
var (
distinct string
dialect = statement.dialect
quote = statement.quote
fromStr = " FROM "
top, mssqlCondi, whereStr string
)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
}
if len(condSQL) > 0 {
whereStr = " WHERE " + condSQL
}
if dialect.DBType() == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
fromStr += statement.TableName()
} else {
fromStr += quote(statement.TableName())
}
if statement.TableAlias != "" {
if dialect.DBType() == schemas.ORACLE {
fromStr += " " + quote(statement.TableAlias)
} else {
fromStr += " AS " + quote(statement.TableAlias)
}
}
if statement.JoinStr != "" {
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
}
pLimitN := statement.LimitN
if dialect.DBType() == schemas.MSSQL {
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
}
if statement.Start > 0 {
var column string
if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes {
if len(index.Cols) == 1 {
column = index.Cols[0]
break
}
}
if len(column) == 0 {
column = statement.RefTable.ColumnsSeq()[0]
}
} else {
column = statement.RefTable.PKColumns()[0].Name
}
if statement.needTableName() {
if len(statement.TableAlias) > 0 {
column = statement.TableAlias + "." + column
} else {
column = statement.TableName() + "." + column
}
}
var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 {
orderStr = " ORDER BY " + statement.OrderStr
}
var groupStr string
if len(statement.GroupByStr) > 0 {
groupStr = " GROUP BY " + statement.GroupByStr
}
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
}
}
var buf strings.Builder
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
if len(mssqlCondi) > 0 {
if len(whereStr) > 0 {
fmt.Fprint(&buf, " AND ", mssqlCondi)
} else {
fmt.Fprint(&buf, " WHERE ", mssqlCondi)
}
}
if statement.GroupByStr != "" {
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
}
if statement.HavingStr != "" {
fmt.Fprint(&buf, " ", statement.HavingStr)
}
if needOrderBy && statement.OrderStr != "" {
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
}
if needLimit {
if dialect.DBType() != schemas.MSSQL && dialect.DBType() != schemas.ORACLE {
if statement.Start > 0 {
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == schemas.ORACLE {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
rawColStr := columnStr
if rawColStr == "*" {
rawColStr = "at.*"
}
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
}
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), nil
}
return buf.String(), nil
}
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
var sqlStr string
var args []interface{}
var joinStr string
var err error
if len(bean) == 0 {
tableName := statement.TableName()
if len(tableName) <= 0 {
return "", nil, ErrTableNotFound
}
tableName = statement.quote(tableName)
if len(statement.JoinStr) > 0 {
joinStr = statement.JoinStr
}
if statement.Conds().IsValid() {
condSQL, condArgs, err := builder.ToSQL(statement.Conds())
if err != nil {
return "", nil, err
}
if statement.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if statement.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
}
args = condArgs
} else {
if statement.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if statement.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
}
args = []interface{}{}
}
} else {
beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return "", nil, errors.New("needs a pointer")
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := statement.SetRefBean(bean[0]); err != nil {
return "", nil, err
}
}
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
statement.Limit(1)
sqlStr, args, err = statement.GenGetSQL(bean[0])
if err != nil {
return "", nil, err
}
}
return sqlStr, args, nil
}
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
if statement.RawSQL != "" {
return statement.RawSQL, statement.RawParams, nil
}
var sqlStr string
var args []interface{}
var err error
if len(statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
statement.cond = statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(statement.cond)
if err != nil {
return "", nil, err
}
args = append(statement.joinArgs, condArgs...)
sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package statements
import ( import (
"fmt" "fmt"
@ -77,7 +77,7 @@ func convertArg(arg interface{}, convertFunc func(string) string) string {
const insertSelectPlaceHolder = true const insertSelectPlaceHolder = true
func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) error { func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) { switch argv := arg.(type) {
case bool: case bool:
if statement.dialect.DBType() == schemas.MSSQL { if statement.dialect.DBType() == schemas.MSSQL {
@ -130,9 +130,9 @@ func (statement *Statement) writeArg(w *builder.BytesWriter, arg interface{}) er
return nil return nil
} }
func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}) error { func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error {
for i, arg := range args { for i, arg := range args {
if err := statement.writeArg(w, arg); err != nil { if err := statement.WriteArg(w, arg); err != nil {
return err return err
} }
@ -144,27 +144,3 @@ func (statement *Statement) writeArgs(w *builder.BytesWriter, args []interface{}
} }
return nil return nil
} }
func writeStrings(w *builder.BytesWriter, cols []string, leftQuote, rightQuote string) error {
for i, colName := range cols {
if len(leftQuote) > 0 && colName[0] != '`' {
if _, err := w.WriteString(leftQuote); err != nil {
return err
}
}
if _, err := w.WriteString(colName); err != nil {
return err
}
if len(rightQuote) > 0 && colName[len(colName)-1] != '`' {
if _, err := w.WriteString(rightQuote); err != nil {
return err
}
}
if i+1 != len(cols) {
if _, err := w.WriteString(","); err != nil {
return err
}
}
}
return nil
}

View File

@ -0,0 +1,184 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"reflect"
"strings"
"testing"
"xorm.io/xorm/schemas"
)
var colStrTests = []struct {
omitColumn string
onlyToDBColumnNdx int
expected string
}{
{"", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"Code2", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"Code3", 1, "`ID`, `Caption`, `Code1`, `Code2`, `ParentID`, `Latitude`, `Longitude`"},
{"Longitude", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"},
{"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"},
}
func TestColumnsStringGeneration(t *testing.T) {
if dbType == "postgres" || dbType == "mssql" {
return
}
var statement *Statement
for ndx, testCase := range colStrTests {
statement = createTestStatement()
if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn)
}
columns := statement.RefTable.Columns()
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB
}
actual := statement.genColumnStr()
if actual != testCase.expected {
t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual)
}
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES
}
}
}
func BenchmarkColumnsStringGeneration(b *testing.B) {
b.StopTimer()
statement := createTestStatement()
testCase := colStrTests[0]
if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
}
if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
}
b.StartTimer()
for i := 0; i < b.N; i++ {
actual := statement.genColumnStr()
if actual != testCase.expected {
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
}
}
}
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StopTimer()
mapCols := make(map[string]bool)
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
{Name: `Code1`},
{Name: `Code2`},
{Name: `Code3`},
{Name: `ParentID`},
{Name: `Latitude`},
{Name: `Longitude`},
}
for _, col := range cols {
mapCols[strings.ToLower(col.Name)] = true
}
b.StartTimer()
for i := 0; i < b.N; i++ {
for _, col := range cols {
if _, ok := getFlagForColumn(mapCols, col); !ok {
b.Fatal("Unexpected result")
}
}
}
}
func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
b.StopTimer()
mapCols := make(map[string]bool)
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
{Name: `Code1`},
{Name: `Code2`},
{Name: `Code3`},
{Name: `ParentID`},
{Name: `Latitude`},
{Name: `Longitude`},
}
b.StartTimer()
for i := 0; i < b.N; i++ {
for _, col := range cols {
if _, ok := getFlagForColumn(mapCols, col); ok {
b.Fatal("Unexpected result")
}
}
}
}
type TestType struct {
ID int64 `xorm:"ID PK"`
IsDeleted bool `xorm:"IsDeleted"`
Caption string `xorm:"Caption"`
Code1 string `xorm:"Code1"`
Code2 string `xorm:"Code2"`
Code3 string `xorm:"Code3"`
ParentID int64 `xorm:"ParentID"`
Latitude float64 `xorm:"Latitude"`
Longitude float64 `xorm:"Longitude"`
}
func (TestType) TableName() string {
return "TestTable"
}
func createTestStatement() *Statement {
if engine, ok := testEngine.(*Engine); ok {
statement := &Statement{}
statement.Reset()
statement.Engine = engine
statement.dialect = engine.dialect
statement.SetRefValue(reflect.ValueOf(TestType{}))
return statement
} else if eg, ok := testEngine.(*EngineGroup); ok {
statement := &Statement{}
statement.Reset()
statement.Engine = eg.Engine
statement.dialect = eg.Engine.dialect
statement.SetRefValue(reflect.ValueOf(TestType{}))
return statement
}
return nil
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package xorm package statements
import ( import (
"reflect" "reflect"

View File

@ -0,0 +1,280 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package statements
import (
"database/sql/driver"
"fmt"
"reflect"
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// BuildUpdates auto generating update columnes and values according a struct
func (statement *Statement) BuildUpdates(bean interface{},
includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) ([]string, []interface{}, error) {
//engine := statement.Engine
table := statement.RefTable
allUseBool := statement.allUseBool
useAllCols := statement.useAllCols
mustColumnMap := statement.MustColumnMap
nullableMap := statement.NullableMap
columnMap := statement.ColumnMap
omitColumnMap := statement.OmitColumnMap
unscoped := statement.unscoped
var colNames = make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns() {
if !includeVersion && col.IsVersion {
continue
}
if col.IsCreated && !columnMap.Contain(col.Name) {
continue
}
if !includeUpdated && col.IsUpdated {
continue
}
if !includeAutoIncr && col.IsAutoIncrement {
continue
}
if col.IsDeleted && !unscoped {
continue
}
if omitColumnMap.Contain(col.Name) {
continue
}
if len(columnMap) > 0 && !columnMap.Contain(col.Name) {
continue
}
if col.MapType == schemas.ONLYFROMDB {
continue
}
if statement.IncrColumns.IsColExist(col.Name) {
continue
} else if statement.DecrColumns.IsColExist(col.Name) {
continue
} else if statement.ExprColumns.IsColExist(col.Name) {
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface())
if fieldType == nil {
continue
}
requiredField := useAllCols
includeNil := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b {
requiredField = true
} else {
continue
}
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if b, ok := getFlagForColumn(nullableMap, col); ok {
if b && col.Nullable && utils.IsZero(fieldValue.Interface()) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
fieldType = reflect.TypeOf(fieldValue.Interface())
includeNil = true
}
}
var val interface{}
if fieldValue.CanAddr() {
if structConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
data, err := structConvert.ToDB()
if err != nil {
return nil, nil, err
}
val = data
goto APPEND
}
}
if structConvert, ok := fieldValue.Interface().(convert.Conversion); ok {
data, err := structConvert.ToDB()
if err != nil {
return nil, nil, err
}
val = data
goto APPEND
}
if fieldType.Kind() == reflect.Ptr {
if fieldValue.IsNil() {
if includeNil {
args = append(args, nil)
colNames = append(colNames, fmt.Sprintf("%v=?", statement.quote(col.Name)))
}
continue
} else if !fieldValue.IsValid() {
continue
} else {
// dereference ptr type to instance type
fieldValue = fieldValue.Elem()
fieldType = reflect.TypeOf(fieldValue.Interface())
requiredField = true
}
}
switch fieldType.Kind() {
case reflect.Bool:
if allUseBool || requiredField {
val = fieldValue.Interface()
} else {
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
continue
}
case reflect.String:
if !requiredField && fieldValue.String() == "" {
continue
}
// for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() {
val = fieldValue.String()
} else {
val = fieldValue.Interface()
}
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if !requiredField && fieldValue.Int() == 0 {
continue
}
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64:
if !requiredField && fieldValue.Float() == 0.0 {
continue
}
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if !requiredField && fieldValue.Uint() == 0 {
continue
}
t := int64(fieldValue.Uint())
val = reflect.ValueOf(&t).Interface()
case reflect.Struct:
if fieldType.ConvertibleTo(schemas.TimeType) {
t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time)
if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
continue
}
val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t)
} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = nulType.Value()
} else {
if !col.SQLType.IsJson() {
table, err := statement.tagParser.MapType(fieldValue)
if err != nil {
val = fieldValue.Interface()
} else {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
if pkField.IsValid() && (!requiredField && !utils.IsZero(pkField.Interface())) {
val = pkField.Interface()
} else {
continue
}
} else {
// TODO: how to handler?
panic("not supported")
}
}
} else {
// Blank struct could not be as update data
if requiredField || !utils.IsStructZero(fieldValue) {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
}
if col.SQLType.IsText() {
val = string(bytes)
} else if col.SQLType.IsBlob() {
val = bytes
}
} else {
continue
}
}
}
case reflect.Array, reflect.Slice, reflect.Map:
if !requiredField {
if fieldValue == reflect.Zero(fieldType) {
continue
}
if fieldType.Kind() == reflect.Array {
if utils.IsArrayZero(fieldValue) {
continue
}
} else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
continue
}
}
if col.SQLType.IsText() {
bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, nil, err
}
val = string(bytes)
} else if col.SQLType.IsBlob() {
var bytes []byte
var err error
if fieldType.Kind() == reflect.Slice &&
fieldType.Elem().Kind() == reflect.Uint8 {
if fieldValue.Len() > 0 {
val = fieldValue.Bytes()
} else {
continue
}
} else if fieldType.Kind() == reflect.Array &&
fieldType.Elem().Kind() == reflect.Uint8 {
val = fieldValue.Slice(0, 0).Interface()
} else {
bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil {
return nil, nil, err
}
val = bytes
}
} else {
continue
}
default:
val = fieldValue.Interface()
}
APPEND:
args = append(args, val)
if col.IsPrimaryKey {
continue
}
colNames = append(colNames, fmt.Sprintf("%v = ?", statement.quote(col.Name)))
}
return colNames, args, nil
}

13
internal/utils/name.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"fmt"
)
func IndexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}

13
internal/utils/reflect.go Normal file
View File

@ -0,0 +1,13 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"reflect"
)
func ReflectValue(bean interface{}) reflect.Value {
return reflect.Indirect(reflect.ValueOf(bean))
}

19
internal/utils/sql.go Normal file
View File

@ -0,0 +1,19 @@
// Copyright 2020 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"strings"
)
func IsSubQuery(tbName string) bool {
const selStr = "select"
if len(tbName) <= len(selStr)+1 {
return false
}
return strings.EqualFold(tbName[:len(selStr)], selStr) ||
strings.EqualFold(tbName[:len(selStr)+1], "("+selStr)
}

30
internal/utils/strings.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2017 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package utils
import (
"strings"
)
func IndexNoCase(s, sep string) int {
return strings.Index(strings.ToLower(s), strings.ToLower(sep))
}
func SplitNoCase(s, sep string) []string {
idx := IndexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.Split(s, s[idx:idx+len(sep)])
}
func SplitNNoCase(s, sep string, n int) []string {
idx := IndexNoCase(s, sep)
if idx < 0 {
return []string{s}
}
return strings.SplitN(s, s[idx:idx+len(sep)], n)
}

View File

@ -10,6 +10,7 @@ import (
"reflect" "reflect"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/internal/utils"
) )
// Rows rows wrapper a rows to // Rows rows wrapper a rows to
@ -29,7 +30,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
var args []interface{} var args []interface{}
var err error var err error
if err = rows.session.statement.setRefBean(bean); err != nil { if err = rows.session.statement.SetRefBean(bean); err != nil {
return nil, err return nil, err
} }
@ -38,7 +39,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
} }
if rows.session.statement.RawSQL == "" { if rows.session.statement.RawSQL == "" {
sqlStr, args, err = rows.session.statement.genGetSQL(bean) sqlStr, args, err = rows.session.statement.GenGetSQL(bean)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -84,7 +85,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType)
} }
if err := rows.session.statement.setRefBean(bean); err != nil { if err := rows.session.statement.SetRefBean(bean); err != nil {
return err return err
} }
@ -98,7 +99,7 @@ func (rows *Rows) Scan(bean interface{}) error {
return err return err
} }
dataStruct := rValue(bean) dataStruct := utils.ReflectValue(bean)
_, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable) _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable)
if err != nil { if err != nil {
return err return err

View File

@ -109,6 +109,40 @@ func (q Quoter) Join(a []string, sep string) string {
return b.String() return b.String()
} }
func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
if len(a) == 0 {
return nil
}
n := len(sep) * (len(a) - 1)
for i := 0; i < len(a); i++ {
n += len(a[i])
}
b.Grow(n)
for i, s := range a {
if i > 0 {
if _, err := b.WriteString(sep); err != nil {
return err
}
}
if q[0] != "" && s != "*" && s[0] != '`' {
if _, err := b.WriteString(q[0]); err != nil {
return err
}
}
if _, err := b.WriteString(strings.TrimSpace(s)); err != nil {
return err
}
if q[1] != "" && s != "*" && s[0] != '`' {
if _, err := b.WriteString(q[1]); err != nil {
return err
}
}
}
return nil
}
func (q Quoter) Strings(s []string) []string { func (q Quoter) Strings(s []string) []string {
var res = make([]string, 0, len(s)) var res = make([]string, 0, len(s))
for _, a := range s { for _, a := range s {

View File

@ -14,8 +14,11 @@ import (
"strings" "strings"
"time" "time"
"xorm.io/xorm/contexts"
"xorm.io/xorm/convert" "xorm.io/xorm/convert"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/statements"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -32,7 +35,7 @@ type Session struct {
db *core.DB db *core.DB
engine *Engine engine *Engine
tx *core.Tx tx *core.Tx
statement Statement statement *statements.Statement
isAutoCommit bool isAutoCommit bool
isCommitedOrRollbacked bool isCommitedOrRollbacked bool
isAutoClose bool isAutoClose bool
@ -73,9 +76,12 @@ func (session *Session) Clone() *Session {
// Init reset the session as the init status. // Init reset the session as the init status.
func (session *Session) Init() { func (session *Session) Init() {
session.statement.Reset() session.statement = statements.NewStatement(
session.statement.dialect = session.engine.dialect session.engine.dialect,
session.statement.Engine = session.engine session.engine.tagParser,
session.engine.DatabaseTZ,
)
session.showSQL = session.engine.showSQL session.showSQL = session.engine.showSQL
session.isAutoCommit = true session.isAutoCommit = true
session.isCommitedOrRollbacked = false session.isCommitedOrRollbacked = false
@ -118,8 +124,8 @@ func (session *Session) Close() {
} }
// ContextCache enable context cache or not // ContextCache enable context cache or not
func (session *Session) ContextCache(context ContextCache) *Session { func (session *Session) ContextCache(context contexts.ContextCache) *Session {
session.statement.context = context session.statement.SetContextCache(context)
return session return session
} }
@ -158,7 +164,9 @@ func (session *Session) After(closures func(interface{})) *Session {
// Table can input a string or pointer to struct for special a table to operate. // Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session { func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.statement.Table(tableNameOrBean) if err := session.statement.SetTable(tableNameOrBean); err != nil {
session.engine.logger.Error(err)
}
return session return session
} }
@ -182,7 +190,7 @@ func (session *Session) ForUpdate() *Session {
// NoAutoCondition disable generate SQL condition from beans // NoAutoCondition disable generate SQL condition from beans
func (session *Session) NoAutoCondition(no ...bool) *Session { func (session *Session) NoAutoCondition(no ...bool) *Session {
session.statement.NoAutoCondition(no...) session.statement.SetNoAutoCondition(no...)
return session return session
} }
@ -288,7 +296,7 @@ func (session *Session) canCache() bool {
!session.statement.UseCache || !session.statement.UseCache ||
session.statement.IsForUpdate || session.statement.IsForUpdate ||
session.tx != nil || session.tx != nil ||
len(session.statement.selectStr) > 0 { len(session.statement.SelectStr) > 0 {
return false return false
} }
return true return true
@ -505,13 +513,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
continue continue
} }
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -535,13 +543,13 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
if len(bs) > 0 { if len(bs) > 0 {
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
err := DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface()) err := json.DefaultJSONHandler.Unmarshal(bs, fieldValue.Addr().Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := DefaultJSONHandler.Unmarshal(bs, x.Interface()) err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -557,7 +565,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
if col.SQLType.IsText() { if col.SQLType.IsText() {
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -672,7 +680,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface()) err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -682,7 +690,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
hasAssigned = true hasAssigned = true
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(vv.Bytes()) > 0 { if len(vv.Bytes()) > 0 {
err := DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface()) err := json.DefaultJSONHandler.Unmarshal(vv.Bytes(), x.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -818,7 +826,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
case schemas.Complex64Type: case schemas.Complex64Type:
var x complex64 var x complex64
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -828,7 +836,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
case schemas.Complex128Type: case schemas.Complex128Type:
var x complex128 var x complex128
if len([]byte(vv.String())) > 0 { if len([]byte(vv.String())) > 0 {
err := DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x) err := json.DefaultJSONHandler.Unmarshal([]byte(vv.String()), &x)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -877,7 +885,7 @@ func (session *Session) LastSQL() (string, []interface{}) {
// Unscoped always disable struct tag "deleted" // Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session { func (session *Session) Unscoped() *Session {
session.statement.Unscoped() session.statement.SetUnscoped()
return session return session
} }

View File

@ -63,19 +63,6 @@ func getFlagForColumn(m map[string]bool, col *schemas.Column) (val bool, has boo
return false, false return false, false
} }
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0, len(columns))
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
// Incr provides a query string like "count = count + 1" // Incr provides a query string like "count = count + 1"
func (session *Session) Incr(column string, arg ...interface{}) *Session { func (session *Session) Incr(column string, arg ...interface{}) *Session {
session.statement.Incr(column, arg...) session.statement.Incr(column, arg...)

View File

@ -51,5 +51,5 @@ func (session *Session) NotIn(column string, args ...interface{}) *Session {
// Conds returns session query conditions except auto bean conditions // Conds returns session query conditions except auto bean conditions
func (session *Session) Conds() builder.Cond { func (session *Session) Conds() builder.Cond {
return session.statement.cond return session.statement.Conds()
} }

View File

@ -15,6 +15,7 @@ import (
"time" "time"
"xorm.io/xorm/convert" "xorm.io/xorm/convert"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -108,7 +109,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, x.Interface()) err := json.DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -122,7 +123,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
if col.SQLType.IsText() { if col.SQLType.IsText() {
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, x.Interface()) err := json.DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -135,7 +136,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
} else { } else {
x := reflect.New(fieldType) x := reflect.New(fieldType)
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, x.Interface()) err := json.DefaultJSONHandler.Unmarshal(data, x.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -264,7 +265,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
case schemas.Complex64Type.Kind(): case schemas.Complex64Type.Kind():
var x complex64 var x complex64
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, &x) err := json.DefaultJSONHandler.Unmarshal(data, &x)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -275,7 +276,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val
case schemas.Complex128Type.Kind(): case schemas.Complex128Type.Kind():
var x complex128 var x complex128
if len(data) > 0 { if len(data) > 0 {
err := DefaultJSONHandler.Unmarshal(data, &x) err := json.DefaultJSONHandler.Unmarshal(data, &x)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return err return err
@ -615,14 +616,14 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
} }
return string(bytes), nil return string(bytes), nil
} else if col.SQLType.IsBlob() { } else if col.SQLType.IsBlob() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
@ -631,7 +632,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
} }
return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
@ -643,7 +644,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface()) bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err
@ -656,7 +657,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect.
(fieldValue.Type().Elem().Kind() == reflect.Uint8) { (fieldValue.Type().Elem().Kind() == reflect.Uint8) {
bytes = fieldValue.Bytes() bytes = fieldValue.Bytes()
} else { } else {
bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface()) bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
return 0, err return 0, err

View File

@ -23,7 +23,7 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(sqlStr)
} }
newsql := session.statement.convertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
@ -80,11 +80,11 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil { if session.statement.LastError != nil {
return 0, session.statement.lastError return 0, session.statement.LastError
} }
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
} }
@ -98,7 +98,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
processor.BeforeDelete() processor.BeforeDelete()
} }
condSQL, condArgs, err := session.statement.genConds(bean) condSQL, condArgs, err := session.statement.GenConds(bean)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -152,7 +152,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
var realSQL string var realSQL string
argsForCache := make([]interface{}, 0, len(condArgs)*2) argsForCache := make([]interface{}, 0, len(condArgs)*2)
if session.statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled
realSQL = deleteSQL realSQL = deleteSQL
copy(argsForCache, condArgs) copy(argsForCache, condArgs)
argsForCache = append(condArgs, argsForCache...) argsForCache = append(condArgs, argsForCache...)

View File

@ -4,89 +4,19 @@
package xorm package xorm
import (
"errors"
"fmt"
"reflect"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
// Exist returns true if the record exist otherwise return false // Exist returns true if the record exist otherwise return false
func (session *Session) Exist(bean ...interface{}) (bool, error) { func (session *Session) Exist(bean ...interface{}) (bool, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil { if session.statement.LastError != nil {
return false, session.statement.lastError return false, session.statement.LastError
} }
var sqlStr string sqlStr, args, err := session.statement.GenExistSQL(bean...)
var args []interface{} if err != nil {
var joinStr string return false, err
var err error
if session.statement.RawSQL == "" {
if len(bean) == 0 {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return false, ErrTableNotFound
}
tableName = session.statement.quote(tableName)
if len(session.statement.JoinStr) > 0 {
joinStr = session.statement.JoinStr
}
if session.statement.cond.IsValid() {
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return false, err
}
if session.engine.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s WHERE %s", tableName, joinStr, condSQL)
} else if session.engine.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE (%s) %s AND ROWNUM=1", tableName, joinStr, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE %s LIMIT 1", tableName, joinStr, condSQL)
}
args = condArgs
} else {
if session.engine.dialect.DBType() == schemas.MSSQL {
sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s %s", tableName, joinStr)
} else if session.engine.dialect.DBType() == schemas.ORACLE {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s WHERE ROWNUM=1", tableName, joinStr)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s %s LIMIT 1", tableName, joinStr)
}
args = []interface{}{}
}
} else {
beanValue := reflect.ValueOf(bean[0])
if beanValue.Kind() != reflect.Ptr {
return false, errors.New("needs a pointer")
}
if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefBean(bean[0]); err != nil {
return false, err
}
}
if len(session.statement.TableName()) <= 0 {
return false, ErrTableNotFound
}
session.statement.Limit(1)
sqlStr, args, err = session.statement.genGetSQL(bean[0])
if err != nil {
return false, err
}
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
} }
rows, err := session.queryRows(sqlStr, args...) rows, err := session.queryRows(sqlStr, args...)

View File

@ -8,10 +8,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -53,8 +53,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
} }
session.autoResetStatement = true session.autoResetStatement = true
if session.statement.selectStr != "" { if session.statement.SelectStr != "" {
session.statement.selectStr = "" session.statement.SelectStr = ""
} }
if session.statement.OrderStr != "" { if session.statement.OrderStr != "" {
session.statement.OrderStr = "" session.statement.OrderStr = ""
@ -66,8 +66,8 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
defer session.resetStatement() defer session.resetStatement()
if session.statement.lastError != nil { if session.statement.LastError != nil {
return session.statement.lastError return session.statement.LastError
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
@ -82,7 +82,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
if err := session.statement.setRefValue(pv); err != nil { if err := session.statement.SetRefValue(pv); err != nil {
return err return err
} }
} else { } else {
@ -90,7 +90,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType) pv := reflect.New(sliceElementType)
if err := session.statement.setRefValue(pv); err != nil { if err := session.statement.SetRefValue(pv); err != nil {
return err return err
} }
} else { } else {
@ -103,16 +103,16 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
var addedTableName = (len(session.statement.JoinStr) > 0) var addedTableName = (len(session.statement.JoinStr) > 0)
var autoCond builder.Cond var autoCond builder.Cond
if tp == tpStruct { if tp == tpStruct {
if !session.statement.noAutoCondition && len(condiBean) > 0 { if !session.statement.NoAutoCondition && len(condiBean) > 0 {
var err error var err error
autoCond, err = session.statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) autoCond, err = session.statement.BuildConds(table, condiBean[0], true, true, false, true, addedTableName)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given. // !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
// See https://gitea.com/xorm/xorm/issues/179 // See https://gitea.com/xorm/xorm/issues/179
if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
var colName = session.engine.Quote(col.Name) var colName = session.engine.Quote(col.Name)
if addedTableName { if addedTableName {
var nm = session.statement.TableName() var nm = session.statement.TableName()
@ -122,70 +122,20 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
colName = session.engine.Quote(nm) + "." + colName colName = session.engine.Quote(nm) + "." + colName
} }
autoCond = session.engine.CondDeleted(col) autoCond = session.statement.CondDeleted(col)
} }
} }
} }
var sqlStr string sqlStr, args, err := session.statement.GenFindSQL(autoCond)
var args []interface{} if err != nil {
var err error return err
if session.statement.RawSQL == "" {
if len(session.statement.TableName()) <= 0 {
return ErrTableNotFound
}
var columnStr = session.statement.columnStr()
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
session.statement.cond = session.statement.cond.And(autoCond)
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return err
}
args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
} }
if session.canCache() { if session.canCache() {
if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
!session.statement.IsDistinct && !session.statement.IsDistinct &&
!session.statement.unscoped { !session.statement.GetUnscoped() {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return err return err
@ -274,7 +224,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface()) dataStruct := utils.ReflectValue(newValue.Interface())
tb, err := session.engine.tagParser.MapType(dataStruct) tb, err := session.engine.tagParser.MapType(dataStruct)
if err != nil { if err != nil {
return err return err
@ -323,8 +273,8 @@ func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) erro
func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) {
if !session.canCache() || if !session.canCache() ||
indexNoCase(sqlStr, "having") != -1 || utils.IndexNoCase(sqlStr, "having") != -1 ||
indexNoCase(sqlStr, "group by") != -1 { utils.IndexNoCase(sqlStr, "group by") != -1 {
return ErrCacheFailed return ErrCacheFailed
} }
@ -338,7 +288,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(sqlStr)
} }
newsql := session.statement.convertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
) )
@ -299,13 +300,13 @@ func TestHaving(t *testing.T) {
func TestOrderSameMapper(t *testing.T) { func TestOrderSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
mapper := testEngine.GetTableMapper() mapper := testEngine.GetTableMapper()
testEngine.SetMapper(names.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
testEngine.SetMapper(mapper) testEngine.SetMapper(mapper)
}() }()
@ -324,12 +325,12 @@ func TestOrderSameMapper(t *testing.T) {
func TestHavingSameMapper(t *testing.T) { func TestHavingSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
mapper := testEngine.GetTableMapper() mapper := testEngine.GetTableMapper()
testEngine.SetMapper(names.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
testEngine.SetMapper(mapper) testEngine.SetMapper(mapper)
}() }()
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))

View File

@ -12,6 +12,7 @@ import (
"strconv" "strconv"
"xorm.io/xorm/caches" "xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -27,8 +28,8 @@ func (session *Session) Get(bean interface{}) (bool, error) {
func (session *Session) get(bean interface{}) (bool, error) { func (session *Session) get(bean interface{}) (bool, error) {
defer session.resetStatement() defer session.resetStatement()
if session.statement.lastError != nil { if session.statement.LastError != nil {
return false, session.statement.lastError return false, session.statement.LastError
} }
beanValue := reflect.ValueOf(bean) beanValue := reflect.ValueOf(bean)
@ -39,7 +40,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
} }
if beanValue.Elem().Kind() == reflect.Struct { if beanValue.Elem().Kind() == reflect.Struct {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return false, err return false, err
} }
} }
@ -53,7 +54,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
return false, ErrTableNotFound return false, ErrTableNotFound
} }
session.statement.Limit(1) session.statement.Limit(1)
sqlStr, args, err = session.statement.genGetSQL(bean) sqlStr, args, err = session.statement.GenGetSQL(bean)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -66,7 +67,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
if session.canCache() && beanValue.Elem().Kind() == reflect.Struct { if session.canCache() && beanValue.Elem().Kind() == reflect.Struct {
if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
!session.statement.unscoped { !session.statement.GetUnscoped() {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return has, err return has, err
@ -74,7 +75,7 @@ func (session *Session) get(bean interface{}) (bool, error) {
} }
} }
context := session.statement.context context := session.statement.Context
if context != nil { if context != nil {
res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args))
if res != nil { if res != nil {
@ -244,7 +245,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table,
// close it before covert data // close it before covert data
rows.Close() rows.Close()
dataStruct := rValue(bean) dataStruct := utils.ReflectValue(bean)
_, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table)
if err != nil { if err != nil {
return true, err return true, err
@ -274,7 +275,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
for _, filter := range session.engine.dialect.Filters() { for _, filter := range session.engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr) sqlStr = filter.Do(sqlStr)
} }
newsql := session.statement.convertIDSQL(sqlStr) newsql := session.statement.ConvertIDSQL(sqlStr)
if newsql == "" { if newsql == "" {
return false, ErrCacheFailed return false, ErrCacheFailed
} }

View File

@ -11,6 +11,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm/contexts"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -417,7 +418,7 @@ func TestContextGet(t *testing.T) {
sess := testEngine.NewSession() sess := testEngine.NewSession()
defer sess.Close() defer sess.Close()
context := NewMemoryContextCache() context := contexts.NewMemoryContextCache()
var c2 ContextGetStruct var c2 ContextGetStruct
has, err := sess.ID(1).NoCache().ContextCache(context).Get(&c2) has, err := sess.ID(1).NoCache().ContextCache(context).Get(&c2)
@ -452,7 +453,7 @@ func TestContextGet2(t *testing.T) {
_, err := testEngine.Insert(&ContextGetStruct2{Name: "1"}) _, err := testEngine.Insert(&ContextGetStruct2{Name: "1"})
assert.NoError(t, err) assert.NoError(t, err)
context := NewMemoryContextCache() context := contexts.NewMemoryContextCache()
var c2 ContextGetStruct2 var c2 ContextGetStruct2
has, err := testEngine.ID(1).NoCache().ContextCache(context).Get(&c2) has, err := testEngine.ID(1).NoCache().ContextCache(context).Get(&c2)

View File

@ -113,7 +113,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, errors.New("could not insert a empty slice") return 0, errors.New("could not insert a empty slice")
} }
if err := session.statement.setRefBean(sliceValue.Index(0).Interface()); err != nil { if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err return 0, err
} }
@ -163,10 +163,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.statement.omitColumnMap.contain(col.Name) { if session.statement.OmitColumnMap.Contain(col.Name) {
continue continue
} }
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue continue
} }
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
@ -178,7 +178,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.statement.checkVersion { } else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1) args = append(args, 1)
var colName = col.Name var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -214,10 +214,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsDeleted { if col.IsDeleted {
continue continue
} }
if session.statement.omitColumnMap.contain(col.Name) { if session.statement.OmitColumnMap.Contain(col.Name) {
continue continue
} }
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue continue
} }
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
@ -229,7 +229,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.statement.checkVersion { } else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1) args = append(args, 1)
var colName = col.Name var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) { session.afterClosures = append(session.afterClosures, func(bean interface{}) {
@ -329,7 +329,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
} }
func (session *Session) innerInsert(bean interface{}) (int64, error) { func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
} }
if len(session.statement.TableName()) <= 0 { if len(session.statement.TableName()) <= 0 {
@ -353,7 +353,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
exprs := session.statement.exprColumns exprs := session.statement.ExprColumns
colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces := strings.Repeat("?, ", len(colNames))
if exprs.Len() <= 0 && len(colPlaces) > 0 { if exprs.Len() <= 0 && len(colPlaces) > 0 {
colPlaces = colPlaces[0 : len(colPlaces)-2] colPlaces = colPlaces[0 : len(colPlaces)-2]
@ -385,25 +385,25 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
if err := writeStrings(buf, append(colNames, exprs.colNames...), "`", "`"); err != nil { if err := session.engine.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil {
return 0, err return 0, err
} }
if session.statement.cond.IsValid() { if session.statement.Conds().IsValid() {
if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil { if _, err := buf.WriteString(fmt.Sprintf(")%s SELECT ", output)); err != nil {
return 0, err return 0, err
} }
if err := session.statement.writeArgs(buf, args); err != nil { if err := session.statement.WriteArgs(buf, args); err != nil {
return 0, err return 0, err
} }
if len(exprs.args) > 0 { if len(exprs.Args) > 0 {
if _, err := buf.WriteString(","); err != nil { if _, err := buf.WriteString(","); err != nil {
return 0, err return 0, err
} }
} }
if err := exprs.writeArgs(buf); err != nil { if err := exprs.WriteArgs(buf); err != nil {
return 0, err return 0, err
} }
@ -411,7 +411,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
if err := session.statement.cond.WriteTo(buf); err != nil { if err := session.statement.Conds().WriteTo(buf); err != nil {
return 0, err return 0, err
} }
} else { } else {
@ -423,7 +423,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 0, err return 0, err
} }
if err := exprs.writeArgs(buf); err != nil { if err := exprs.WriteArgs(buf); err != nil {
return 0, err return 0, err
} }
@ -482,7 +482,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.cacheInsert(tableName) session.cacheInsert(tableName)
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
@ -523,7 +523,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.cacheInsert(tableName) session.cacheInsert(tableName)
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
@ -564,7 +564,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.cacheInsert(tableName) session.cacheInsert(tableName)
if table.Version != "" && session.statement.checkVersion { if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil { if err != nil {
session.engine.logger.Error(err) session.engine.logger.Error(err)
@ -637,19 +637,19 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
continue continue
} }
if session.statement.omitColumnMap.contain(col.Name) { if session.statement.OmitColumnMap.Contain(col.Name) {
continue continue
} }
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue continue
} }
if session.statement.incrColumns.isColExist(col.Name) { if session.statement.IncrColumns.IsColExist(col.Name) {
continue continue
} else if session.statement.decrColumns.isColExist(col.Name) { } else if session.statement.DecrColumns.IsColExist(col.Name) {
continue continue
} else if session.statement.exprColumns.isColExist(col.Name) { } else if session.statement.ExprColumns.IsColExist(col.Name) {
continue continue
} }
@ -681,7 +681,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
} }
// !evalphobia! set fieldValue as nil when column is nullable and zero-value // !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
if col.Nullable && utils.IsValueZero(fieldValue) { if col.Nullable && utils.IsValueZero(fieldValue) {
var nilValue *int var nilValue *int
fieldValue = reflect.ValueOf(nilValue) fieldValue = reflect.ValueOf(nilValue)
@ -698,7 +698,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.statement.checkVersion { } else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1) args = append(args, 1)
} else { } else {
arg, err := session.value2Interface(col, fieldValue) arg, err := session.value2Interface(col, fieldValue)
@ -724,9 +724,9 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
} }
var columns = make([]string, 0, len(m)) var columns = make([]string, 0, len(m))
exprs := session.statement.exprColumns exprs := session.statement.ExprColumns
for k := range m { for k := range m {
if !exprs.isColExist(k) { if !exprs.IsColExist(k) {
columns = append(columns, k) columns = append(columns, k)
} }
} }
@ -751,9 +751,9 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
} }
var columns = make([]string, 0, len(m)) var columns = make([]string, 0, len(m))
exprs := session.statement.exprColumns exprs := session.statement.ExprColumns
for k := range m { for k := range m {
if !exprs.isColExist(k) { if !exprs.IsColExist(k) {
columns = append(columns, k) columns = append(columns, k)
} }
} }
@ -774,15 +774,15 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
exprs := session.statement.exprColumns exprs := session.statement.ExprColumns
w := builder.NewWriter() w := builder.NewWriter()
// if insert where // if insert where
if session.statement.cond.IsValid() { if session.statement.Conds().IsValid() {
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil { if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
return 0, err return 0, err
} }
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
return 0, err return 0, err
} }
@ -790,15 +790,15 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
return 0, err return 0, err
} }
if err := session.statement.writeArgs(w, args); err != nil { if err := session.statement.WriteArgs(w, args); err != nil {
return 0, err return 0, err
} }
if len(exprs.args) > 0 { if len(exprs.Args) > 0 {
if _, err := w.WriteString(","); err != nil { if _, err := w.WriteString(","); err != nil {
return 0, err return 0, err
} }
if err := exprs.writeArgs(w); err != nil { if err := exprs.WriteArgs(w); err != nil {
return 0, err return 0, err
} }
} }
@ -807,7 +807,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
return 0, err return 0, err
} }
if err := session.statement.cond.WriteTo(w); err != nil { if err := session.statement.Conds().WriteTo(w); err != nil {
return 0, err return 0, err
} }
} else { } else {
@ -818,7 +818,7 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
return 0, err return 0, err
} }
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil { if err := session.engine.dialect.Quoter().JoinWrite(w.Builder, append(columns, exprs.ColNames...), ","); err != nil {
return 0, err return 0, err
} }
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil { if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
@ -826,11 +826,11 @@ func (session *Session) insertMap(columns []string, args []interface{}) (int64,
} }
w.Append(args...) w.Append(args...)
if len(exprs.args) > 0 { if len(exprs.Args) > 0 {
if _, err := w.WriteString(","); err != nil { if _, err := w.WriteString(","); err != nil {
return 0, err return 0, err
} }
if err := exprs.writeArgs(w); err != nil { if err := exprs.WriteArgs(w); err != nil {
return 0, err return 0, err
} }
} }

View File

@ -6,6 +6,8 @@ package xorm
import ( import (
"reflect" "reflect"
"xorm.io/xorm/internal/utils"
) )
// IterFunc only use by Iterate // IterFunc only use by Iterate
@ -25,11 +27,11 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil { if session.statement.LastError != nil {
return session.statement.lastError return session.statement.LastError
} }
if session.statement.bufferSize > 0 { if session.statement.BufferSize > 0 {
return session.bufferIterate(bean, fun) return session.bufferIterate(bean, fun)
} }
@ -57,18 +59,18 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
// BufferSize sets the buffersize for iterate // BufferSize sets the buffersize for iterate
func (session *Session) BufferSize(size int) *Session { func (session *Session) BufferSize(size int) *Session {
session.statement.bufferSize = size session.statement.BufferSize = size
return session return session
} }
func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error { func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
var bufferSize = session.statement.bufferSize var bufferSize = session.statement.BufferSize
var pLimitN = session.statement.LimitN var pLimitN = session.statement.LimitN
if pLimitN != nil && bufferSize > *pLimitN { if pLimitN != nil && bufferSize > *pLimitN {
bufferSize = *pLimitN bufferSize = *pLimitN
} }
var start = session.statement.Start var start = session.statement.Start
v := rValue(bean) v := utils.ReflectValue(bean)
sliceType := reflect.SliceOf(v.Type()) sliceType := reflect.SliceOf(v.Type())
var idx = 0 var idx = 0
session.autoResetStatement = false session.autoResetStatement = false

View File

@ -8,83 +8,19 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"time" "time"
"xorm.io/builder"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
func (session *Session) genQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
if len(sqlOrArgs) > 0 {
return convertSQLOrArgs(sqlOrArgs...)
}
if session.statement.RawSQL != "" {
return session.statement.RawSQL, session.statement.RawParams, nil
}
if len(session.statement.TableName()) <= 0 {
return "", nil, ErrTableNotFound
}
var columnStr = session.statement.columnStr()
if len(session.statement.selectStr) > 0 {
columnStr = session.statement.selectStr
} else {
if session.statement.JoinStr == "" {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr)
} else {
columnStr = session.statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.statement.GroupByStr != "" {
columnStr = session.statement.quoteColumnStr(session.statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := session.statement.processIDParam(); err != nil {
return "", nil, err
}
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
if err != nil {
return "", nil, err
}
args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true, true)
if err != nil {
return "", nil, err
}
// for mssql and use limit
qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs {
args = append(args, args...)
}
return sqlStr, args, nil
}
// Query runs a raw sql and return records as []map[string][]byte // Query runs a raw sql and return records as []map[string][]byte
func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) { func (session *Session) Query(sqlOrArgs ...interface{}) ([]map[string][]byte, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -233,7 +169,7 @@ func (session *Session) QueryString(sqlOrArgs ...interface{}) ([]map[string]stri
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -253,7 +189,7 @@ func (session *Session) QuerySliceString(sqlOrArgs ...interface{}) ([][]string,
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -306,7 +242,7 @@ func (session *Session) QueryInterface(sqlOrArgs ...interface{}) ([]map[string]i
defer session.Close() defer session.Close()
} }
sqlStr, args, err := session.genQuerySQL(sqlOrArgs...) sqlStr, args, err := session.statement.GenQuerySQL(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,8 +9,8 @@ import (
"reflect" "reflect"
"time" "time"
"xorm.io/builder"
"xorm.io/xorm/core" "xorm.io/xorm/core"
"xorm.io/xorm/internal/statements"
) )
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
@ -196,20 +196,6 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
return session.DB().ExecContext(session.ctx, sqlStr, args...) return session.DB().ExecContext(session.ctx, sqlStr, args...)
} }
func convertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
switch sqlOrArgs[0].(type) {
case string:
return sqlOrArgs[0].(string), sqlOrArgs[1:], nil
case *builder.Builder:
return sqlOrArgs[0].(*builder.Builder).ToSQL()
case builder.Builder:
bd := sqlOrArgs[0].(builder.Builder)
return bd.ToSQL()
}
return "", nil, ErrUnSupportedType
}
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) {
if session.isAutoClose { if session.isAutoClose {
@ -220,7 +206,7 @@ func (session *Session) Exec(sqlOrArgs ...interface{}) (sql.Result, error) {
return nil, ErrUnSupportedType return nil, ErrUnSupportedType
} }
sqlStr, args, err := convertSQLOrArgs(sqlOrArgs...) sqlStr, args, err := statements.ConvertSQLOrArgs(sqlOrArgs...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -33,11 +33,11 @@ func (session *Session) CreateTable(bean interface{}) error {
} }
func (session *Session) createTable(bean interface{}) error { func (session *Session) createTable(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqlStr := session.statement.genCreateTableSQL() sqlStr := session.statement.GenCreateTableSQL()
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
@ -52,11 +52,11 @@ func (session *Session) CreateIndexes(bean interface{}) error {
} }
func (session *Session) createIndexes(bean interface{}) error { func (session *Session) createIndexes(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqls := session.statement.genIndexSQL() sqls := session.statement.GenIndexSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
@ -75,11 +75,11 @@ func (session *Session) CreateUniques(bean interface{}) error {
} }
func (session *Session) createUniques(bean interface{}) error { func (session *Session) createUniques(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqls := session.statement.genUniqueSQL() sqls := session.statement.GenUniqueSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
@ -99,11 +99,11 @@ func (session *Session) DropIndexes(bean interface{}) error {
} }
func (session *Session) dropIndexes(bean interface{}) error { func (session *Session) dropIndexes(bean interface{}) error {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return err return err
} }
sqls := session.statement.genDelIndexSQL() sqls := session.statement.GenDelIndexSQL()
for _, sqlStr := range sqls { for _, sqlStr := range sqls {
_, err := session.exec(sqlStr) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
@ -201,7 +201,7 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
func (session *Session) addColumn(colName string) error { func (session *Session) addColumn(colName string) error {
col := session.statement.RefTable.GetColumn(colName) col := session.statement.RefTable.GetColumn(colName)
sql := session.statement.dialect.AddColumnSQL(session.statement.TableName(), col) sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col)
_, err := session.exec(sql) _, err := session.exec(sql)
return err return err
} }
@ -241,7 +241,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}() }()
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := utils.ReflectValue(bean)
table, err := engine.tagParser.MapType(v) table, err := engine.tagParser.MapType(v)
if err != nil { if err != nil {
return err return err
@ -299,7 +299,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// column is not exist on table // column is not exist on table
if oriCol == nil { if oriCol == nil {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil { if err = session.addColumn(col.Name); err != nil {
return err return err
} }
@ -409,11 +409,11 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == schemas.UniqueType { if index.Type == schemas.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType { } else if index.Type == schemas.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name) err = session.addIndex(tbNameWithSchema, name)
} }
if err != nil { if err != nil {

View File

@ -17,17 +17,9 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
defer session.Close() defer session.Close()
} }
var sqlStr string sqlStr, args, err := session.statement.GenCountSQL(bean...)
var args []interface{} if err != nil {
var err error return 0, err
if session.statement.RawSQL == "" {
sqlStr, args, err = session.statement.genCountSQL(bean...)
if err != nil {
return 0, err
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
} }
var total int64 var total int64
@ -50,21 +42,12 @@ func (session *Session) sum(res interface{}, bean interface{}, columnNames ...st
return errors.New("need a pointer to a variable") return errors.New("need a pointer to a variable")
} }
var isSlice = v.Elem().Kind() == reflect.Slice sqlStr, args, err := session.statement.GenSumSQL(bean, columnNames...)
var sqlStr string if err != nil {
var args []interface{} return err
var err error
if len(session.statement.RawSQL) == 0 {
sqlStr, args, err = session.statement.genSumSQL(bean, columnNames...)
if err != nil {
return err
}
} else {
sqlStr = session.statement.RawSQL
args = session.statement.RawParams
} }
if isSlice { if v.Elem().Kind() == reflect.Slice {
err = session.queryRow(sqlStr, args...).ScanSlice(res) err = session.queryRow(sqlStr, args...).ScanSlice(res)
} else { } else {
err = session.queryRow(sqlStr, args...).Scan(res) err = session.queryRow(sqlStr, args...).Scan(res)

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
) )
@ -85,10 +86,10 @@ func TestCombineTransactionSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
testEngine.SetMapper(names.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)
}() }()

View File

@ -23,7 +23,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
return ErrCacheFailed return ErrCacheFailed
} }
oldhead, newsql := session.statement.convertUpdateSQL(sqlStr) oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
@ -88,12 +88,12 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
return err return err
} }
if bean := cacher.GetBean(tableName, sid); bean != nil { if bean := cacher.GetBean(tableName, sid); bean != nil {
sqls := splitNNoCase(sqlStr, "where", 2) sqls := utils.SplitNNoCase(sqlStr, "where", 2)
if len(sqls) == 0 || len(sqls) > 2 { if len(sqls) == 0 || len(sqls) > 2 {
return ErrCacheFailed return ErrCacheFailed
} }
sqls = splitNNoCase(sqls[0], "set", 2) sqls = utils.SplitNNoCase(sqls[0], "set", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
return ErrCacheFailed return ErrCacheFailed
} }
@ -112,7 +112,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
session.engine.logger.Error(err) session.engine.logger.Error(err)
} else { } else {
session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.statement.checkVersion { if col.IsVersion && session.statement.CheckVersion {
session.incrVersionFieldValue(fieldValue) session.incrVersionFieldValue(fieldValue)
} else { } else {
fieldValue.Set(reflect.ValueOf(args[idx])) fieldValue.Set(reflect.ValueOf(args[idx]))
@ -144,11 +144,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
defer session.Close() defer session.Close()
} }
if session.statement.lastError != nil { if session.statement.LastError != nil {
return 0, session.statement.lastError return 0, session.statement.LastError
} }
v := rValue(bean) v := utils.ReflectValue(bean)
t := v.Type() t := v.Type()
var colNames []string var colNames []string
@ -168,7 +168,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var isMap = t.Kind() == reflect.Map var isMap = t.Kind() == reflect.Map
var isStruct = t.Kind() == reflect.Struct var isStruct = t.Kind() == reflect.Struct
if isStruct { if isStruct {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
} }
@ -176,14 +176,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
if session.statement.columnStr() == "" { if session.statement.ColumnStr() == "" {
colNames, args = session.statement.buildUpdates(bean, false, false, colNames, args, err = session.statement.BuildUpdates(bean, false, false,
false, false, true) false, false, true)
} else { } else {
colNames, args, err = session.genUpdateColumns(bean) colNames, args, err = session.genUpdateColumns(bean)
if err != nil { }
return 0, err if err != nil {
} return 0, err
} }
} else if isMap { } else if isMap {
colNames = make([]string, 0) colNames = make([]string, 0)
@ -201,8 +201,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table := session.statement.RefTable table := session.statement.RefTable
if session.statement.UseAutoTime && table != nil && table.Updated != "" { if session.statement.UseAutoTime && table != nil && table.Updated != "" {
if !session.statement.columnMap.contain(table.Updated) && if !session.statement.ColumnMap.Contain(table.Updated) &&
!session.statement.omitColumnMap.contain(table.Updated) { !session.statement.OmitColumnMap.Contain(table.Updated) {
colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?") colNames = append(colNames, session.engine.Quote(table.Updated)+" = ?")
col := table.UpdatedColumn() col := table.UpdatedColumn()
val, t := session.engine.nowTime(col) val, t := session.engine.nowTime(col)
@ -219,21 +219,21 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
// for update action to like "column = column + ?" // for update action to like "column = column + ?"
incColumns := session.statement.incrColumns incColumns := session.statement.IncrColumns
for i, colName := range incColumns.colNames { for i, colName := range incColumns.ColNames {
colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?") colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" + ?")
args = append(args, incColumns.args[i]) args = append(args, incColumns.Args[i])
} }
// for update action to like "column = column - ?" // for update action to like "column = column - ?"
decColumns := session.statement.decrColumns decColumns := session.statement.DecrColumns
for i, colName := range decColumns.colNames { for i, colName := range decColumns.ColNames {
colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?") colNames = append(colNames, session.engine.Quote(colName)+" = "+session.engine.Quote(colName)+" - ?")
args = append(args, decColumns.args[i]) args = append(args, decColumns.Args[i])
} }
// for update action to like "column = expression" // for update action to like "column = expression"
exprColumns := session.statement.exprColumns exprColumns := session.statement.ExprColumns
for i, colName := range exprColumns.colNames { for i, colName := range exprColumns.ColNames {
switch tp := exprColumns.args[i].(type) { switch tp := exprColumns.Args[i].(type) {
case string: case string:
if len(tp) == 0 { if len(tp) == 0 {
tp = "''" tp = "''"
@ -248,16 +248,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, subArgs...) args = append(args, subArgs...)
default: default:
colNames = append(colNames, session.engine.Quote(colName)+"=?") colNames = append(colNames, session.engine.Quote(colName)+"=?")
args = append(args, exprColumns.args[i]) args = append(args, exprColumns.Args[i])
} }
} }
if err = session.statement.processIDParam(); err != nil { if err = session.statement.ProcessIDParam(); err != nil {
return 0, err return 0, err
} }
var autoCond builder.Cond var autoCond builder.Cond
if !session.statement.noAutoCondition { if !session.statement.NoAutoCondition {
condBeanIsStruct := false condBeanIsStruct := false
if len(condiBean) > 0 { if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok { if c, ok := condiBean[0].(map[string]interface{}); ok {
@ -270,7 +270,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
if k == reflect.Struct { if k == reflect.Struct {
var err error var err error
autoCond, err = session.statement.buildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false) autoCond, err = session.statement.BuildConds(session.statement.RefTable, condiBean[0], true, true, false, true, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -282,8 +282,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
if !condBeanIsStruct && table != nil { if !condBeanIsStruct && table != nil {
if col := table.DeletedColumn(); col != nil && !session.statement.unscoped { // tag "deleted" is enabled if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
autoCond1 := session.engine.CondDeleted(col) autoCond1 := session.statement.CondDeleted(col)
if autoCond == nil { if autoCond == nil {
autoCond = autoCond1 autoCond = autoCond1
@ -294,15 +294,15 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
st := &session.statement st := session.statement
var ( var (
sqlStr string sqlStr string
condArgs []interface{} condArgs []interface{}
condSQL string condSQL string
cond = session.statement.cond.And(autoCond) cond = session.statement.Conds().And(autoCond)
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.checkVersion) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
verValue *reflect.Value verValue *reflect.Value
) )
if doIncVer { if doIncVer {
@ -335,9 +335,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var top string var top string
if st.LimitN != nil { if st.LimitN != nil {
limitValue := *st.LimitN limitValue := *st.LimitN
if st.dialect.DBType() == schemas.MYSQL { if session.engine.dialect.DBType() == schemas.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue) condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.dialect.DBType() == schemas.SQLITE { } else if session.engine.dialect.DBType() == schemas.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -348,7 +348,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if st.dialect.DBType() == schemas.POSTGRES { } else if session.engine.dialect.DBType() == schemas.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...)) session.engine.Quote(tableName), tempCondSQL), condArgs...))
@ -360,8 +360,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
} else if st.dialect.DBType() == schemas.MSSQL { } else if session.engine.dialect.DBType() == schemas.MSSQL {
if st.OrderStr != "" && st.dialect.DBType() == schemas.MSSQL && if st.OrderStr != "" && session.engine.dialect.DBType() == schemas.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 { table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
@ -459,7 +459,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
for _, col := range table.Columns() { for _, col := range table.Columns() {
if !col.IsVersion && !col.IsCreated && !col.IsUpdated { if !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if session.statement.omitColumnMap.contain(col.Name) { if session.statement.OmitColumnMap.Contain(col.Name) {
continue continue
} }
} }
@ -494,25 +494,25 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
} }
} }
if (col.IsDeleted && !session.statement.unscoped) || col.IsCreated { if (col.IsDeleted && !session.statement.GetUnscoped()) || col.IsCreated {
continue continue
} }
// if only update specify columns // if only update specify columns
if len(session.statement.columnMap) > 0 && !session.statement.columnMap.contain(col.Name) { if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue continue
} }
if session.statement.incrColumns.isColExist(col.Name) { if session.statement.IncrColumns.IsColExist(col.Name) {
continue continue
} else if session.statement.decrColumns.isColExist(col.Name) { } else if session.statement.DecrColumns.IsColExist(col.Name) {
continue continue
} else if session.statement.exprColumns.isColExist(col.Name) { } else if session.statement.ExprColumns.IsColExist(col.Name) {
continue continue
} }
// !evalphobia! set fieldValue as nil when column is nullable and zero-value // !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok { if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
if col.Nullable && utils.IsValueZero(fieldValue) { if col.Nullable && utils.IsValueZero(fieldValue) {
var nilValue *int var nilValue *int
fieldValue = reflect.ValueOf(nilValue) fieldValue = reflect.ValueOf(nilValue)
@ -529,7 +529,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
col := table.GetColumn(colName) col := table.GetColumn(colName)
setColumnTime(bean, col, t) setColumnTime(bean, col, t)
}) })
} else if col.IsVersion && session.statement.checkVersion { } else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1) args = append(args, 1)
} else { } else {
arg, err := session.value2Interface(col, fieldValue) arg, err := session.value2Interface(col, fieldValue)

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
) )
@ -685,20 +686,20 @@ func TestUpdateSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
oldMapper := testEngine.GetTableMapper() oldMapper := testEngine.GetTableMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type())
testEngine.UnMapType(rValue(new(Article)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Article)).Type())
testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type())
testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type())
testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type())
testEngine.SetMapper(names.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Condi)).Type())
testEngine.UnMapType(rValue(new(Article)).Type()) testEngine.UnMapType(utils.ReflectValue(new(Article)).Type())
testEngine.UnMapType(rValue(new(UpdateAllCols)).Type()) testEngine.UnMapType(utils.ReflectValue(new(UpdateAllCols)).Type())
testEngine.UnMapType(rValue(new(UpdateMustCols)).Type()) testEngine.UnMapType(utils.ReflectValue(new(UpdateMustCols)).Type())
testEngine.UnMapType(rValue(new(UpdateIncr)).Type()) testEngine.UnMapType(utils.ReflectValue(new(UpdateIncr)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)
}() }()

View File

@ -5,185 +5,11 @@
package xorm package xorm
import ( import (
"reflect"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm/schemas"
) )
var colStrTests = []struct {
omitColumn string
onlyToDBColumnNdx int
expected string
}{
{"", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"Code2", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"},
{"Code3", 1, "`ID`, `Caption`, `Code1`, `Code2`, `ParentID`, `Latitude`, `Longitude`"},
{"Longitude", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"},
{"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"},
}
func TestColumnsStringGeneration(t *testing.T) {
if dbType == "postgres" || dbType == "mssql" {
return
}
var statement *Statement
for ndx, testCase := range colStrTests {
statement = createTestStatement()
if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn)
}
columns := statement.RefTable.Columns()
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB
}
actual := statement.genColumnStr()
if actual != testCase.expected {
t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual)
}
if testCase.onlyToDBColumnNdx >= 0 {
columns[testCase.onlyToDBColumnNdx].MapType = schemas.TWOSIDES
}
}
}
func BenchmarkColumnsStringGeneration(b *testing.B) {
b.StopTimer()
statement := createTestStatement()
testCase := colStrTests[0]
if testCase.omitColumn != "" {
statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped
}
if testCase.onlyToDBColumnNdx >= 0 {
columns := statement.RefTable.Columns()
columns[testCase.onlyToDBColumnNdx].MapType = schemas.ONLYTODB // !nemec784! Column must be skipped
}
b.StartTimer()
for i := 0; i < b.N; i++ {
actual := statement.genColumnStr()
if actual != testCase.expected {
b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual)
}
}
}
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
b.StopTimer()
mapCols := make(map[string]bool)
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
{Name: `Code1`},
{Name: `Code2`},
{Name: `Code3`},
{Name: `ParentID`},
{Name: `Latitude`},
{Name: `Longitude`},
}
for _, col := range cols {
mapCols[strings.ToLower(col.Name)] = true
}
b.StartTimer()
for i := 0; i < b.N; i++ {
for _, col := range cols {
if _, ok := getFlagForColumn(mapCols, col); !ok {
b.Fatal("Unexpected result")
}
}
}
}
func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
b.StopTimer()
mapCols := make(map[string]bool)
cols := []*schemas.Column{
{Name: `ID`},
{Name: `IsDeleted`},
{Name: `Caption`},
{Name: `Code1`},
{Name: `Code2`},
{Name: `Code3`},
{Name: `ParentID`},
{Name: `Latitude`},
{Name: `Longitude`},
}
b.StartTimer()
for i := 0; i < b.N; i++ {
for _, col := range cols {
if _, ok := getFlagForColumn(mapCols, col); ok {
b.Fatal("Unexpected result")
}
}
}
}
type TestType struct {
ID int64 `xorm:"ID PK"`
IsDeleted bool `xorm:"IsDeleted"`
Caption string `xorm:"Caption"`
Code1 string `xorm:"Code1"`
Code2 string `xorm:"Code2"`
Code3 string `xorm:"Code3"`
ParentID int64 `xorm:"ParentID"`
Latitude float64 `xorm:"Latitude"`
Longitude float64 `xorm:"Longitude"`
}
func (TestType) TableName() string {
return "TestTable"
}
func createTestStatement() *Statement {
if engine, ok := testEngine.(*Engine); ok {
statement := &Statement{}
statement.Reset()
statement.Engine = engine
statement.dialect = engine.dialect
statement.setRefValue(reflect.ValueOf(TestType{}))
return statement
} else if eg, ok := testEngine.(*EngineGroup); ok {
statement := &Statement{}
statement.Reset()
statement.Engine = eg.Engine
statement.dialect = eg.Engine.dialect
statement.setRefValue(reflect.ValueOf(TestType{}))
return statement
}
return nil
}
func TestDistinctAndCols(t *testing.T) { func TestDistinctAndCols(t *testing.T) {
type DistinctAndCols struct { type DistinctAndCols struct {
Id int64 Id int64

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -608,10 +609,10 @@ func TestGonicMapperID(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type())
testEngine.SetMapper(names.LintGonicMapper) testEngine.SetMapper(names.LintGonicMapper)
defer func() { defer func() {
testEngine.UnMapType(rValue(new(IDGonicMapper)).Type()) testEngine.UnMapType(utils.ReflectValue(new(IDGonicMapper)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)
}() }()
@ -645,10 +646,10 @@ func TestSameMapperID(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetColumnMapper()
testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type())
testEngine.SetMapper(names.SameMapper{}) testEngine.SetMapper(names.SameMapper{})
defer func() { defer func() {
testEngine.UnMapType(rValue(new(IDSameMapper)).Type()) testEngine.UnMapType(utils.ReflectValue(new(IDSameMapper)).Type())
testEngine.SetMapper(oldMapper) testEngine.SetMapper(oldMapper)
}() }()
@ -818,7 +819,9 @@ func TestAutoIncrTag(t *testing.T) {
Id int64 Id int64
} }
tb := testEngine.TableInfo(new(TestAutoIncr1)) tb, err := testEngine.TableInfo(new(TestAutoIncr1))
assert.NoError(t, err)
cols := tb.Columns() cols := tb.Columns()
assert.EqualValues(t, 1, len(cols)) assert.EqualValues(t, 1, len(cols))
assert.True(t, cols[0].IsAutoIncrement) assert.True(t, cols[0].IsAutoIncrement)
@ -829,7 +832,9 @@ func TestAutoIncrTag(t *testing.T) {
Id int64 `xorm:"id"` Id int64 `xorm:"id"`
} }
tb = testEngine.TableInfo(new(TestAutoIncr2)) tb, err = testEngine.TableInfo(new(TestAutoIncr2))
assert.NoError(t, err)
cols = tb.Columns() cols = tb.Columns()
assert.EqualValues(t, 1, len(cols)) assert.EqualValues(t, 1, len(cols))
assert.False(t, cols[0].IsAutoIncrement) assert.False(t, cols[0].IsAutoIncrement)
@ -840,7 +845,9 @@ func TestAutoIncrTag(t *testing.T) {
Id int64 `xorm:"'ID'"` Id int64 `xorm:"'ID'"`
} }
tb = testEngine.TableInfo(new(TestAutoIncr3)) tb, err = testEngine.TableInfo(new(TestAutoIncr3))
assert.NoError(t, err)
cols = tb.Columns() cols = tb.Columns()
assert.EqualValues(t, 1, len(cols)) assert.EqualValues(t, 1, len(cols))
assert.False(t, cols[0].IsAutoIncrement) assert.False(t, cols[0].IsAutoIncrement)
@ -851,7 +858,9 @@ func TestAutoIncrTag(t *testing.T) {
Id int64 `xorm:"pk"` Id int64 `xorm:"pk"`
} }
tb = testEngine.TableInfo(new(TestAutoIncr4)) tb, err = testEngine.TableInfo(new(TestAutoIncr4))
assert.NoError(t, err)
cols = tb.Columns() cols = tb.Columns()
assert.EqualValues(t, 1, len(cols)) assert.EqualValues(t, 1, len(cols))
assert.False(t, cols[0].IsAutoIncrement) assert.False(t, cols[0].IsAutoIncrement)
@ -1035,7 +1044,9 @@ func TestTagDefault5(t *testing.T) {
} }
assertSync(t, new(DefaultStruct5)) assertSync(t, new(DefaultStruct5))
table := testEngine.TableInfo(new(DefaultStruct5)) table, err := testEngine.TableInfo(new(DefaultStruct5))
assert.NoError(t, err)
createdCol := table.GetColumn("created") createdCol := table.GetColumn("created")
assert.NotNil(t, createdCol) assert.NotNil(t, createdCol)
assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default)

View File

@ -10,6 +10,7 @@ import (
"testing" "testing"
"xorm.io/xorm/convert" "xorm.io/xorm/convert"
"xorm.io/xorm/internal/json"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -118,21 +119,21 @@ type ConvConfig struct {
} }
func (s *ConvConfig) FromDB(data []byte) error { func (s *ConvConfig) FromDB(data []byte) error {
return DefaultJSONHandler.Unmarshal(data, s) return json.DefaultJSONHandler.Unmarshal(data, s)
} }
func (s *ConvConfig) ToDB() ([]byte, error) { func (s *ConvConfig) ToDB() ([]byte, error) {
return DefaultJSONHandler.Marshal(s) return json.DefaultJSONHandler.Marshal(s)
} }
type SliceType []*ConvConfig type SliceType []*ConvConfig
func (s *SliceType) FromDB(data []byte) error { func (s *SliceType) FromDB(data []byte) error {
return DefaultJSONHandler.Unmarshal(data, s) return json.DefaultJSONHandler.Unmarshal(data, s)
} }
func (s *SliceType) ToDB() ([]byte, error) { func (s *SliceType) ToDB() ([]byte, error) {
return DefaultJSONHandler.Marshal(s) return json.DefaultJSONHandler.Marshal(s)
} }
type ConvStruct struct { type ConvStruct struct {