refactor table name

This commit is contained in:
Lunny Xiao 2020-01-29 13:17:11 +08:00
parent 14a0c19a0c
commit 1b71702c21
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
10 changed files with 202 additions and 108 deletions

View File

@ -215,14 +215,15 @@ func quoteTo(buf *strings.Builder, quotePair string, value string) {
_, _ = buf.WriteString(value)
return
}
prefix, suffix := quotePair[0], quotePair[1]
prefix, suffix := quotePair[0], quotePair[1]
lastCh := 0 // 0 prefix, 1 char, 2 suffix
i := 0
for i < len(value) {
// start of a token; might be already quoted
if value[i] == '.' {
_ = buf.WriteByte('.')
lastCh = 1
i++
} else if value[i] == prefix || value[i] == '`' {
// Has quotes; skip/normalize `name` to prefix+name+sufix
@ -234,18 +235,37 @@ func quoteTo(buf *strings.Builder, quotePair string, value string) {
}
i++
_ = buf.WriteByte(prefix)
for ; i < len(value) && value[i] != ch; i++ {
lastCh = 0
for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i])
lastCh = 1
}
_ = buf.WriteByte(suffix)
lastCh = 2
i++
} else if value[i] == ' ' {
if lastCh != 2 {
_ = buf.WriteByte(suffix)
lastCh = 2
}
// a AS b or a b
for ; i < len(value); i++ {
if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) {
break
}
_ = buf.WriteByte(value[i])
lastCh = 1
}
} else {
// Requires quotes
_ = buf.WriteByte(prefix)
for ; i < len(value) && value[i] != '.'; i++ {
for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ {
_ = buf.WriteByte(value[i])
lastCh = 1
}
_ = buf.WriteByte(suffix)
lastCh = 2
}
}
}
@ -918,10 +938,18 @@ var (
)
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
table := core.NewEmptyTable()
table.Name = tbNameForMap(engine.TableMapper, v)
t := v.Type()
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
table.Type = t
table.Name = engine.tbNameForMap(v)
fmt.Println("======", table.Name)
var idFieldColName string
var hasCacheTag, hasNoCacheTag bool

View File

@ -27,46 +27,61 @@ func (engine *Engine) tbNameWithSchema(v string) string {
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
tbName, _ := newTableName(engine.TableMapper, bean)
if len(includeSchema) > 0 && includeSchema[0] {
tbName = engine.tbNameWithSchema(tbName)
tbName.schema = engine.dialect.URI().Schema
return tbName.withSchema()
}
return tbName
return tbName.withNoSchema()
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
if len(session.statement.altTableName) > 0 {
return session.statement.altTableName
}
return table.Name
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {
func tbNameForMap(mapper core.IMapper, v reflect.Value) string {
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if t, ok := v.Interface().(TableName); ok {
return t.TableName()
}
if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName()
}
}
return engine.TableMapper.Obj2Table(v.Type().Name())
return mapper.Obj2Table(v.Type().Name())
}
func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
type tableName struct {
name string
schema string
alias string
aliasSplitter string
}
func newTableName(mapper core.IMapper, tablename interface{}) (tableName, error) {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
return tableName{name: t[0], alias: t[1]}, nil
} else if len(t) == 1 {
return engine.Quote(t[0])
return tableName{name: t[0]}, nil
}
return tableName{}, ErrTableNotFound
case []interface{}:
t := tablename.([]interface{})
l := len(t)
@ -82,32 +97,56 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
table = engine.tbNameForMap(v)
table = tbNameForMap(mapper, v)
} else {
table = engine.Quote(fmt.Sprintf("%v", f))
table = fmt.Sprintf("%v", f)
}
}
}
if l > 1 {
return fmt.Sprintf("%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
return tableName{name: table, alias: fmt.Sprintf("%v", t[1])}, nil
} else if l == 1 {
return engine.Quote(table)
return tableName{name: table}, nil
}
case TableName:
return tablename.(TableName).TableName()
fmt.Println("+++++++++++++++++++++++++", tablename.(TableName).TableName())
return tableName{name: tablename.(TableName).TableName()}, nil
case string:
return tablename.(string)
return tableName{name: tablename.(string)}, nil
case reflect.Value:
v := tablename.(reflect.Value)
return engine.tbNameForMap(v)
return tableName{name: tbNameForMap(mapper, v)}, nil
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
return engine.tbNameForMap(v)
return tableName{name: tbNameForMap(mapper, v)}, nil
}
return engine.Quote(fmt.Sprintf("%v", tablename))
return tableName{name: fmt.Sprintf("%v", tablename)}, nil
}
return ""
return tableName{}, ErrTableNotFound
}
func (t tableName) withSchema() string {
if t.schema == "" {
return t.withNoSchema()
}
if t.alias != "" {
if t.aliasSplitter != "" {
return fmt.Sprintf("%s.%s %s %s", t.schema, t.name, t.aliasSplitter, t.alias)
}
return fmt.Sprintf("%s.%s %s", t.schema, t.name, t.alias)
}
return fmt.Sprintf("%s.%s", t.schema, t.name)
}
func (t tableName) withNoSchema() string {
if t.alias != "" {
if t.aliasSplitter != "" {
return fmt.Sprintf("%s %s %s", t.name, t.aliasSplitter, t.alias)
}
return fmt.Sprintf("%s %s", t.name, t.alias)
}
return t.name
}

View File

@ -115,8 +115,8 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
var colName = session.engine.Quote(col.Name)
if addedTableName {
var nm = session.statement.TableName()
if len(session.statement.TableAlias) > 0 {
nm = session.statement.TableAlias
if len(session.statement.tableAlias) > 0 {
nm = session.statement.tableAlias
}
colName = session.engine.Quote(nm) + "." + colName
}

View File

@ -10,8 +10,8 @@ import (
"testing"
"time"
"xorm.io/core"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)
func TestJoinLimit(t *testing.T) {
@ -801,3 +801,25 @@ func TestFindJoin(t *testing.T) {
Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes)
assert.NoError(t, err)
}
func TestJoinReverseWord(t *testing.T) {
type JoinReverseWord struct {
Id int64
Name string
}
type JoinReverseWord2 struct {
Id int64
UserId int64 `xorm:"index"`
Age int
}
assert.NoError(t, prepareEngine())
err := testEngine.Table("order").Sync2(new(JoinReverseWord))
assert.NoError(t, err)
assertSync(t, new(JoinReverseWord2))
var j2 []JoinReverseWord2
err = testEngine.Join("INNER", "order", "`join_reverse_word2`.user_id=`order`.id").Find(&j2)
assert.NoError(t, err)
}

View File

@ -492,6 +492,8 @@ func TestGetCustomTableInterface(t *testing.T) {
assert.NoError(t, err)
assert.True(t, exist)
assert.EqualValues(t, getCustomTableName, testEngine.TableInfo(new(MyGetCustomTableImpletation)).Name)
_, err = testEngine.Insert(&MyGetCustomTableImpletation{
Name: "xlw",
})

View File

@ -324,7 +324,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.setRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
var tableName = session.statement.TableName()
fmt.Println("------", tableName)
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
@ -351,7 +353,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
colPlaces = colPlaces[0 : len(colPlaces)-2]
}
var tableName = session.statement.TableName()
var output string
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)

View File

@ -245,9 +245,13 @@ func (session *Session) Sync2(beans ...interface{}) error {
if err != nil {
return err
}
var tbName string
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
var (
tbName string
altTableName = session.statement.altTableName
)
if len(altTableName) > 0 {
tbName = altTableName
} else {
tbName = engine.TableName(bean)
}
@ -298,7 +302,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
session.statement.altTableName = altTableName
if err = session.addColumn(col.Name); err != nil {
return err
}
@ -406,11 +410,11 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames {
if index.Type == core.UniqueType {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
session.statement.altTableName = altTableName
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType {
session.statement.RefTable = table
session.statement.tableName = tbNameWithSchema
session.statement.altTableName = altTableName
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {

View File

@ -389,13 +389,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableAlias = session.engine.Quote(tableName)
var fromSQL string
if session.statement.TableAlias != "" {
if session.statement.tableAlias != "" {
switch session.engine.dialect.DBType() {
case core.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.tableAlias)
tableAlias = session.statement.tableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias)
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.tableAlias)
}
}

View File

@ -31,8 +31,7 @@ type Statement struct {
selectStr string
useAllCols bool
OmitStr string
AltTableName string
tableName string
altTableName string
RawSQL string
RawParams []interface{}
UseCascade bool
@ -44,7 +43,7 @@ type Statement struct {
noAutoCondition bool
IsDistinct bool
IsForUpdate bool
TableAlias string
tableAlias string
allUseBool bool
checkVersion bool
unscoped bool
@ -76,8 +75,7 @@ func (statement *Statement) Init() {
statement.OmitStr = ""
statement.columnMap = columnMap{}
statement.omitColumnMap = columnMap{}
statement.AltTableName = ""
statement.tableName = ""
statement.altTableName = ""
statement.idParam = nil
statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0)
@ -86,7 +84,7 @@ func (statement *Statement) Init() {
statement.noAutoCondition = false
statement.IsDistinct = false
statement.IsForUpdate = false
statement.TableAlias = ""
statement.tableAlias = ""
statement.selectStr = ""
statement.allUseBool = false
statement.useAllCols = false
@ -114,7 +112,7 @@ func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
// Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement {
statement.TableAlias = alias
statement.tableAlias = alias
return statement
}
@ -209,22 +207,31 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
func (statement *Statement) setRefValue(v reflect.Value) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
if err != nil {
return err
}
statement.tableName = statement.Engine.TableName(v, true)
return nil
statement.RefTable, err = statement.Engine.autoMapType(v)
return err
}
func (statement *Statement) setRefBean(bean interface{}) error {
var err error
statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
if err != nil {
return err
return statement.setRefValue(reflect.ValueOf(bean))
}
func (statement *Statement) getTableName() tableName {
var name = statement.altTableName
if name == "" && statement.RefTable != nil {
name = statement.RefTable.Name
}
var aliasSplitter = "AS"
if statement.Engine.dialect.DBType() == core.MSSQL {
aliasSplitter = ""
}
return tableName{
name: name,
alias: statement.tableAlias,
aliasSplitter: aliasSplitter,
schema: statement.Engine.Dialect().URI().Schema,
}
statement.tableName = statement.Engine.TableName(bean, true)
return nil
}
// Auto generating update columnes and values according a struct
@ -492,28 +499,27 @@ func (statement *Statement) buildUpdates(bean interface{},
return colNames, args
}
func (statement *Statement) needTableName() bool {
func (statement *Statement) colsNeedTableName() bool {
return len(statement.JoinStr) > 0
}
func (statement *Statement) colName(col *core.Column, tableName string) string {
if statement.needTableName() {
var nm = tableName
if len(statement.TableAlias) > 0 {
nm = statement.TableAlias
}
return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
func (statement *Statement) writeColName(buf *strings.Builder, colName string) {
quotePair := statement.Engine.Dialect().Quote("")
if statement.colsNeedTableName() {
tbname := statement.getTableName()
quoteTo(buf, quotePair, tbname.withSchema())
buf.WriteByte('.')
}
return statement.Engine.Quote(col.Name)
quoteTo(buf, quotePair, colName)
}
// TableName return current tableName
func (statement *Statement) TableName() string {
if statement.AltTableName != "" {
return statement.AltTableName
// fullColName return a column name with schema/table name and quotes
func (statement *Statement) fullColName(colName string) string {
if statement.colsNeedTableName() {
tbname := statement.getTableName()
return tbname.withSchema() + "." + statement.Engine.Quote(colName)
}
return statement.tableName
return statement.Engine.Quote(colName)
}
// ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
@ -716,18 +722,24 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.setRefValue(v)
}
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
statement.altTableName = statement.Engine.TableName(tableNameOrBean, false)
return statement
}
// TableName return table name
func (statement *Statement) TableName() string {
if statement.altTableName != "" {
return statement.altTableName
}
if statement.RefTable != nil {
return statement.RefTable.Name
}
return ""
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf strings.Builder
@ -764,7 +776,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
tbName := statement.Engine.TableName(tablename, true)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
fmt.Println("------", tbName)
fmt.Fprintf(&buf, "%s ON %v", statement.Engine.Quote(tbName), condition)
}
statement.JoinStr = buf.String()
@ -815,17 +828,7 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ")
}
if statement.JoinStr != "" {
if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias)
} else {
buf.WriteString(statement.TableName())
}
buf.WriteString(".")
}
statement.Engine.QuoteTo(&buf, col.Name)
statement.writeColName(&buf, col.Name)
}
return buf.String()
@ -902,7 +905,7 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.tableAlias, addedTableName)
}
func (statement *Statement) mergeConds(bean interface{}) error {
@ -1060,11 +1063,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
fromStr += quote(statement.TableName())
}
if statement.TableAlias != "" {
if statement.tableAlias != "" {
if dialect.DBType() == core.ORACLE {
fromStr += " " + quote(statement.TableAlias)
fromStr += " " + quote(statement.tableAlias)
} else {
fromStr += " AS " + quote(statement.TableAlias)
fromStr += " AS " + quote(statement.tableAlias)
}
}
if statement.JoinStr != "" {
@ -1090,13 +1093,8 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
} else {
column = statement.RefTable.PKColumns()[0].Name
}
if statement.needTableName() {
if len(statement.TableAlias) > 0 {
column = statement.TableAlias + "." + column
} else {
column = statement.TableName() + "." + column
}
}
column = statement.fullColName(column)
var orderStr string
if needOrderBy && len(statement.OrderStr) > 0 {
@ -1171,7 +1169,7 @@ func (statement *Statement) processIDParam() error {
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
var colName = statement.fullColName(col.Name)
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
}
return nil

View File

@ -10,8 +10,8 @@ import (
"testing"
"time"
"xorm.io/core"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)
type tempUser struct {