refactor table name

This commit is contained in:
Lunny Xiao 2020-01-29 13:17:11 +08:00
parent 14a0c19a0c
commit 1b71702c21
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
10 changed files with 202 additions and 108 deletions

View File

@ -217,12 +217,13 @@ func quoteTo(buf *strings.Builder, quotePair string, value string) {
} }
prefix, suffix := quotePair[0], quotePair[1] prefix, suffix := quotePair[0], quotePair[1]
lastCh := 0 // 0 prefix, 1 char, 2 suffix
i := 0 i := 0
for i < len(value) { for i < len(value) {
// start of a token; might be already quoted // start of a token; might be already quoted
if value[i] == '.' { if value[i] == '.' {
_ = buf.WriteByte('.') _ = buf.WriteByte('.')
lastCh = 1
i++ i++
} else if value[i] == prefix || value[i] == '`' { } else if value[i] == prefix || value[i] == '`' {
// Has quotes; skip/normalize `name` to prefix+name+sufix // Has quotes; skip/normalize `name` to prefix+name+sufix
@ -234,18 +235,37 @@ func quoteTo(buf *strings.Builder, quotePair string, value string) {
} }
i++ i++
_ = buf.WriteByte(prefix) _ = buf.WriteByte(prefix)
for ; i < len(value) && value[i] != ch; i++ { lastCh = 0
for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i]) _ = buf.WriteByte(value[i])
lastCh = 1
} }
_ = buf.WriteByte(suffix) _ = buf.WriteByte(suffix)
lastCh = 2
i++ i++
} else if value[i] == ' ' {
if lastCh != 2 {
_ = buf.WriteByte(suffix)
lastCh = 2
}
// a AS b or a b
for ; i < len(value); i++ {
if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) {
break
}
_ = buf.WriteByte(value[i])
lastCh = 1
}
} else { } else {
// Requires quotes // Requires quotes
_ = buf.WriteByte(prefix) _ = buf.WriteByte(prefix)
for ; i < len(value) && value[i] != '.'; i++ { for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i]) _ = buf.WriteByte(value[i])
lastCh = 1
} }
_ = buf.WriteByte(suffix) _ = buf.WriteByte(suffix)
lastCh = 2
} }
} }
} }
@ -918,10 +938,18 @@ var (
) )
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
table := core.NewEmptyTable() table := core.NewEmptyTable()
table.Name = tbNameForMap(engine.TableMapper, v)
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
table.Type = t table.Type = t
table.Name = engine.tbNameForMap(v)
fmt.Println("======", table.Name)
var idFieldColName string var idFieldColName string
var hasCacheTag, hasNoCacheTag bool var hasCacheTag, hasNoCacheTag bool

View File

@ -27,46 +27,61 @@ func (engine *Engine) tbNameWithSchema(v string) string {
// TableName returns table name with schema prefix if has // TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean) tbName, _ := newTableName(engine.TableMapper, bean)
if len(includeSchema) > 0 && includeSchema[0] { if len(includeSchema) > 0 && includeSchema[0] {
tbName = engine.tbNameWithSchema(tbName) tbName.schema = engine.dialect.URI().Schema
return tbName.withSchema()
} }
return tbName return tbName.withNoSchema()
} }
// tbName get some table's table name // tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string { func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 { if len(session.statement.altTableName) > 0 {
return session.statement.AltTableName return session.statement.altTableName
} }
return table.Name return table.Name
} }
func (engine *Engine) tbNameForMap(v reflect.Value) string { func tbNameForMap(mapper core.IMapper, v reflect.Value) string {
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}
if v.Type().Implements(tpTableName) { if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName() return v.Interface().(TableName).TableName()
} }
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
v = v.Elem() v = v.Elem()
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}
if v.Type().Implements(tpTableName) { if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName() return v.Interface().(TableName).TableName()
} }
} }
return engine.TableMapper.Obj2Table(v.Type().Name()) return mapper.Obj2Table(v.Type().Name())
} }
func (engine *Engine) tbNameNoSchema(tablename interface{}) string { type tableName struct {
name string
schema string
alias string
aliasSplitter string
}
func newTableName(mapper core.IMapper, tablename interface{}) (tableName, error) {
switch tablename.(type) { switch tablename.(type) {
case []string: case []string:
t := tablename.([]string) t := tablename.([]string)
if len(t) > 1 { if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) return tableName{name: t[0], alias: t[1]}, nil
} else if len(t) == 1 { } else if len(t) == 1 {
return engine.Quote(t[0]) return tableName{name: t[0]}, nil
} }
return tableName{}, ErrTableNotFound
case []interface{}: case []interface{}:
t := tablename.([]interface{}) t := tablename.([]interface{})
l := len(t) l := len(t)
@ -82,32 +97,56 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
v := rValue(f) v := rValue(f)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
table = engine.tbNameForMap(v) table = tbNameForMap(mapper, v)
} else { } else {
table = engine.Quote(fmt.Sprintf("%v", f)) table = fmt.Sprintf("%v", f)
} }
} }
} }
if l > 1 { if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table), return tableName{name: table, alias: fmt.Sprintf("%v", t[1])}, nil
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 { } else if l == 1 {
return engine.Quote(table) return tableName{name: table}, nil
} }
case TableName: case TableName:
return tablename.(TableName).TableName() fmt.Println("+++++++++++++++++++++++++", tablename.(TableName).TableName())
return tableName{name: tablename.(TableName).TableName()}, nil
case string: case string:
return tablename.(string) return tableName{name: tablename.(string)}, nil
case reflect.Value: case reflect.Value:
v := tablename.(reflect.Value) v := tablename.(reflect.Value)
return engine.tbNameForMap(v) return tableName{name: tbNameForMap(mapper, v)}, nil
default: default:
v := rValue(tablename) v := rValue(tablename)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
return engine.tbNameForMap(v) return tableName{name: tbNameForMap(mapper, v)}, nil
} }
return engine.Quote(fmt.Sprintf("%v", tablename)) return tableName{name: fmt.Sprintf("%v", tablename)}, nil
} }
return "" return tableName{}, ErrTableNotFound
}
func (t tableName) withSchema() string {
if t.schema == "" {
return t.withNoSchema()
}
if t.alias != "" {
if t.aliasSplitter != "" {
return fmt.Sprintf("%s.%s %s %s", t.schema, t.name, t.aliasSplitter, t.alias)
}
return fmt.Sprintf("%s.%s %s", t.schema, t.name, t.alias)
}
return fmt.Sprintf("%s.%s", t.schema, t.name)
}
func (t tableName) withNoSchema() string {
if t.alias != "" {
if t.aliasSplitter != "" {
return fmt.Sprintf("%s %s %s", t.name, t.aliasSplitter, t.alias)
}
return fmt.Sprintf("%s %s", t.name, t.alias)
}
return t.name
} }

View File

@ -115,8 +115,8 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
var colName = session.engine.Quote(col.Name) var colName = session.engine.Quote(col.Name)
if addedTableName { if addedTableName {
var nm = session.statement.TableName() var nm = session.statement.TableName()
if len(session.statement.TableAlias) > 0 { if len(session.statement.tableAlias) > 0 {
nm = session.statement.TableAlias nm = session.statement.tableAlias
} }
colName = session.engine.Quote(nm) + "." + colName colName = session.engine.Quote(nm) + "." + colName
} }

View File

@ -10,8 +10,8 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
func TestJoinLimit(t *testing.T) { func TestJoinLimit(t *testing.T) {
@ -801,3 +801,25 @@ func TestFindJoin(t *testing.T) {
Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes)
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestJoinReverseWord(t *testing.T) {
type JoinReverseWord struct {
Id int64
Name string
}
type JoinReverseWord2 struct {
Id int64
UserId int64 `xorm:"index"`
Age int
}
assert.NoError(t, prepareEngine())
err := testEngine.Table("order").Sync2(new(JoinReverseWord))
assert.NoError(t, err)
assertSync(t, new(JoinReverseWord2))
var j2 []JoinReverseWord2
err = testEngine.Join("INNER", "order", "`join_reverse_word2`.user_id=`order`.id").Find(&j2)
assert.NoError(t, err)
}

View File

@ -492,6 +492,8 @@ func TestGetCustomTableInterface(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exist) assert.True(t, exist)
assert.EqualValues(t, getCustomTableName, testEngine.TableInfo(new(MyGetCustomTableImpletation)).Name)
_, err = testEngine.Insert(&MyGetCustomTableImpletation{ _, err = testEngine.Insert(&MyGetCustomTableImpletation{
Name: "xlw", Name: "xlw",
}) })

View File

@ -324,7 +324,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefBean(bean); err != nil { if err := session.statement.setRefBean(bean); err != nil {
return 0, err return 0, err
} }
if len(session.statement.TableName()) <= 0 { var tableName = session.statement.TableName()
fmt.Println("------", tableName)
if len(tableName) <= 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
@ -351,7 +353,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
colPlaces = colPlaces[0 : len(colPlaces)-2] colPlaces = colPlaces[0 : len(colPlaces)-2]
} }
var tableName = session.statement.TableName()
var output string var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)

View File

@ -245,9 +245,13 @@ func (session *Session) Sync2(beans ...interface{}) error {
if err != nil { if err != nil {
return err return err
} }
var tbName string
if len(session.statement.AltTableName) > 0 { var (
tbName = session.statement.AltTableName tbName string
altTableName = session.statement.altTableName
)
if len(altTableName) > 0 {
tbName = altTableName
} else { } else {
tbName = engine.TableName(bean) tbName = engine.TableName(bean)
} }
@ -298,7 +302,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// column is not exist on table // column is not exist on table
if oriCol == nil { if oriCol == nil {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.altTableName = altTableName
if err = session.addColumn(col.Name); err != nil { if err = session.addColumn(col.Name); err != nil {
return err return err
} }
@ -406,11 +410,11 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.altTableName = altTableName
err = session.addUnique(tbNameWithSchema, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema session.statement.altTableName = altTableName
err = session.addIndex(tbNameWithSchema, name) err = session.addIndex(tbNameWithSchema, name)
} }
if err != nil { if err != nil {

View File

@ -389,13 +389,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableAlias = session.engine.Quote(tableName) var tableAlias = session.engine.Quote(tableName)
var fromSQL string var fromSQL string
if session.statement.TableAlias != "" { if session.statement.tableAlias != "" {
switch session.engine.dialect.DBType() { switch session.engine.dialect.DBType() {
case core.MSSQL: case core.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.tableAlias)
tableAlias = session.statement.TableAlias tableAlias = session.statement.tableAlias
default: default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias) tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.tableAlias)
} }
} }

View File

@ -31,8 +31,7 @@ type Statement struct {
selectStr string selectStr string
useAllCols bool useAllCols bool
OmitStr string OmitStr string
AltTableName string altTableName string
tableName string
RawSQL string RawSQL string
RawParams []interface{} RawParams []interface{}
UseCascade bool UseCascade bool
@ -44,7 +43,7 @@ type Statement struct {
noAutoCondition bool noAutoCondition bool
IsDistinct bool IsDistinct bool
IsForUpdate bool IsForUpdate bool
TableAlias string tableAlias string
allUseBool bool allUseBool bool
checkVersion bool checkVersion bool
unscoped bool unscoped bool
@ -76,8 +75,7 @@ func (statement *Statement) Init() {
statement.OmitStr = "" statement.OmitStr = ""
statement.columnMap = columnMap{} statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{} statement.omitColumnMap = columnMap{}
statement.AltTableName = "" statement.altTableName = ""
statement.tableName = ""
statement.idParam = nil statement.idParam = nil
statement.RawSQL = "" statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0) statement.RawParams = make([]interface{}, 0)
@ -86,7 +84,7 @@ func (statement *Statement) Init() {
statement.noAutoCondition = false statement.noAutoCondition = false
statement.IsDistinct = false statement.IsDistinct = false
statement.IsForUpdate = false statement.IsForUpdate = false
statement.TableAlias = "" statement.tableAlias = ""
statement.selectStr = "" statement.selectStr = ""
statement.allUseBool = false statement.allUseBool = false
statement.useAllCols = false statement.useAllCols = false
@ -114,7 +112,7 @@ func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
// Alias set the table alias // Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement { func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias statement.tableAlias = alias
return statement return statement
} }
@ -209,22 +207,31 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
func (statement *Statement) setRefValue(v reflect.Value) error { func (statement *Statement) setRefValue(v reflect.Value) error {
var err error var err error
statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v)) statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil { return err
return err
}
statement.tableName = statement.Engine.TableName(v, true)
return nil
} }
func (statement *Statement) setRefBean(bean interface{}) error { func (statement *Statement) setRefBean(bean interface{}) error {
var err error return statement.setRefValue(reflect.ValueOf(bean))
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) }
if err != nil {
return err func (statement *Statement) getTableName() tableName {
var name = statement.altTableName
if name == "" && statement.RefTable != nil {
name = statement.RefTable.Name
}
var aliasSplitter = "AS"
if statement.Engine.dialect.DBType() == core.MSSQL {
aliasSplitter = ""
}
return tableName{
name: name,
alias: statement.tableAlias,
aliasSplitter: aliasSplitter,
schema: statement.Engine.Dialect().URI().Schema,
} }
statement.tableName = statement.Engine.TableName(bean, true)
return nil
} }
// Auto generating update columnes and values according a struct // Auto generating update columnes and values according a struct
@ -492,28 +499,27 @@ func (statement *Statement) buildUpdates(bean interface{},
return colNames, args return colNames, args
} }
func (statement *Statement) needTableName() bool { func (statement *Statement) colsNeedTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.JoinStr) > 0
} }
func (statement *Statement) colName(col *core.Column, tableName string) string { func (statement *Statement) writeColName(buf *strings.Builder, colName string) {
if statement.needTableName() { quotePair := statement.Engine.Dialect().Quote("")
var nm = tableName if statement.colsNeedTableName() {
if len(statement.TableAlias) > 0 { tbname := statement.getTableName()
nm = statement.TableAlias quoteTo(buf, quotePair, tbname.withSchema())
} buf.WriteByte('.')
return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
} }
return statement.Engine.Quote(col.Name) quoteTo(buf, quotePair, colName)
} }
// TableName return current tableName // fullColName return a column name with schema/table name and quotes
func (statement *Statement) TableName() string { func (statement *Statement) fullColName(colName string) string {
if statement.AltTableName != "" { if statement.colsNeedTableName() {
return statement.AltTableName tbname := statement.getTableName()
return tbname.withSchema() + "." + statement.Engine.Quote(colName)
} }
return statement.Engine.Quote(colName)
return statement.tableName
} }
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
@ -716,18 +722,24 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean) v := rValue(tableNameOrBean)
t := v.Type() t := v.Type()
if t.Kind() == reflect.Struct { if t.Kind() == reflect.Struct {
var err error statement.setRefValue(v)
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
} }
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) statement.altTableName = statement.Engine.TableName(tableNameOrBean, false)
return statement return statement
} }
// TableName return table name
func (statement *Statement) TableName() string {
if statement.altTableName != "" {
return statement.altTableName
}
if statement.RefTable != nil {
return statement.RefTable.Name
}
return ""
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder var buf strings.Builder
@ -764,7 +776,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
statement.joinArgs = append(statement.joinArgs, subQueryArgs...) statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default: default:
tbName := statement.Engine.TableName(tablename, true) tbName := statement.Engine.TableName(tablename, true)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition) fmt.Println("------", tbName)
fmt.Fprintf(&buf, "%s ON %v", statement.Engine.Quote(tbName), condition)
} }
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
@ -815,17 +828,7 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ") buf.WriteString(", ")
} }
if statement.JoinStr != "" { statement.writeColName(&buf, col.Name)
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.Engine.QuoteTo(&buf, col.Name)
} }
return buf.String() return buf.String()
@ -902,7 +905,7 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.tableAlias, addedTableName)
} }
func (statement *Statement) mergeConds(bean interface{}) error { func (statement *Statement) mergeConds(bean interface{}) error {
@ -1060,11 +1063,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
fromStr += quote(statement.TableName()) fromStr += quote(statement.TableName())
} }
if statement.TableAlias != "" { if statement.tableAlias != "" {
if dialect.DBType() == core.ORACLE { if dialect.DBType() == core.ORACLE {
fromStr += " " + quote(statement.TableAlias) fromStr += " " + quote(statement.tableAlias)
} else { } else {
fromStr += " AS " + quote(statement.TableAlias) fromStr += " AS " + quote(statement.tableAlias)
} }
} }
if statement.JoinStr != "" { if statement.JoinStr != "" {
@ -1090,13 +1093,8 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} else { } else {
column = statement.RefTable.PKColumns()[0].Name column = statement.RefTable.PKColumns()[0].Name
} }
if statement.needTableName() {
if len(statement.TableAlias) > 0 { column = statement.fullColName(column)
column = statement.TableAlias + "." + column
} else {
column = statement.TableName() + "." + column
}
}
var orderStr string var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 { if needOrderBy && len(statement.OrderStr) > 0 {
@ -1171,7 +1169,7 @@ func (statement *Statement) processIDParam() error {
} }
for i, col := range statement.RefTable.PKColumns() { for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName()) var colName = statement.fullColName(col.Name)
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
} }
return nil return nil

View File

@ -10,8 +10,8 @@ import (
"testing" "testing"
"time" "time"
"xorm.io/core"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"xorm.io/core"
) )
type tempUser struct { type tempUser struct {