diff --git a/engine.go b/engine.go index a7e52ea4..495cb1d4 100644 --- a/engine.go +++ b/engine.go @@ -215,14 +215,15 @@ func quoteTo(buf *strings.Builder, quotePair string, value string) { _, _ = buf.WriteString(value) return } - - prefix, suffix := quotePair[0], quotePair[1] + prefix, suffix := quotePair[0], quotePair[1] + lastCh := 0 // 0 prefix, 1 char, 2 suffix i := 0 for i < len(value) { // start of a token; might be already quoted if value[i] == '.' { _ = buf.WriteByte('.') + lastCh = 1 i++ } else if value[i] == prefix || value[i] == '`' { // Has quotes; skip/normalize `name` to prefix+name+sufix @@ -234,18 +235,37 @@ func quoteTo(buf *strings.Builder, quotePair string, value string) { } i++ _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != ch; i++ { + lastCh = 0 + for ; i < len(value) && value[i] != ch && value[i] != ' '; i++ { _ = buf.WriteByte(value[i]) + lastCh = 1 } _ = buf.WriteByte(suffix) + lastCh = 2 i++ + } else if value[i] == ' ' { + if lastCh != 2 { + _ = buf.WriteByte(suffix) + lastCh = 2 + } + + // a AS b or a b + for ; i < len(value); i++ { + if value[i] != ' ' && value[i-1] == ' ' && (len(value) > i+1 && !strings.EqualFold(value[i:i+2], "AS")) { + break + } + _ = buf.WriteByte(value[i]) + lastCh = 1 + } } else { // Requires quotes _ = buf.WriteByte(prefix) - for ; i < len(value) && value[i] != '.'; i++ { + for ; i < len(value) && value[i] != '.' && value[i] != ' '; i++ { _ = buf.WriteByte(value[i]) + lastCh = 1 } _ = buf.WriteByte(suffix) + lastCh = 2 } } } @@ -918,10 +938,18 @@ var ( ) func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { - t := v.Type() table := core.NewEmptyTable() + table.Name = tbNameForMap(engine.TableMapper, v) + + t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + table.Type = t - table.Name = engine.tbNameForMap(v) + + fmt.Println("======", table.Name) var idFieldColName string var hasCacheTag, hasNoCacheTag bool diff --git a/engine_table.go b/engine_table.go index eb5aa850..0e7bac24 100644 --- a/engine_table.go +++ b/engine_table.go @@ -27,46 +27,61 @@ func (engine *Engine) tbNameWithSchema(v string) string { // TableName returns table name with schema prefix if has func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string { - tbName := engine.tbNameNoSchema(bean) + tbName, _ := newTableName(engine.TableMapper, bean) if len(includeSchema) > 0 && includeSchema[0] { - tbName = engine.tbNameWithSchema(tbName) + tbName.schema = engine.dialect.URI().Schema + return tbName.withSchema() } - return tbName + return tbName.withNoSchema() } // tbName get some table's table name func (session *Session) tbNameNoSchema(table *core.Table) string { - if len(session.statement.AltTableName) > 0 { - return session.statement.AltTableName + if len(session.statement.altTableName) > 0 { + return session.statement.altTableName } return table.Name } -func (engine *Engine) tbNameForMap(v reflect.Value) string { +func tbNameForMap(mapper core.IMapper, v reflect.Value) string { + if t, ok := v.Interface().(TableName); ok { + return t.TableName() + } if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() } if v.Kind() == reflect.Ptr { v = v.Elem() + if t, ok := v.Interface().(TableName); ok { + return t.TableName() + } if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() } } - return engine.TableMapper.Obj2Table(v.Type().Name()) + return mapper.Obj2Table(v.Type().Name()) } -func (engine *Engine) tbNameNoSchema(tablename interface{}) string { +type tableName struct { + name string + schema string + alias string + aliasSplitter string +} + +func newTableName(mapper core.IMapper, tablename interface{}) (tableName, error) { switch tablename.(type) { case []string: t := tablename.([]string) if len(t) > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) + return tableName{name: t[0], alias: t[1]}, nil } else if len(t) == 1 { - return engine.Quote(t[0]) + return tableName{name: t[0]}, nil } + return tableName{}, ErrTableNotFound case []interface{}: t := tablename.([]interface{}) l := len(t) @@ -82,32 +97,56 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string { v := rValue(f) t := v.Type() if t.Kind() == reflect.Struct { - table = engine.tbNameForMap(v) + table = tbNameForMap(mapper, v) } else { - table = engine.Quote(fmt.Sprintf("%v", f)) + table = fmt.Sprintf("%v", f) } } } if l > 1 { - return fmt.Sprintf("%v AS %v", engine.Quote(table), - engine.Quote(fmt.Sprintf("%v", t[1]))) + return tableName{name: table, alias: fmt.Sprintf("%v", t[1])}, nil } else if l == 1 { - return engine.Quote(table) + return tableName{name: table}, nil } case TableName: - return tablename.(TableName).TableName() + fmt.Println("+++++++++++++++++++++++++", tablename.(TableName).TableName()) + return tableName{name: tablename.(TableName).TableName()}, nil case string: - return tablename.(string) + return tableName{name: tablename.(string)}, nil case reflect.Value: v := tablename.(reflect.Value) - return engine.tbNameForMap(v) + return tableName{name: tbNameForMap(mapper, v)}, nil default: v := rValue(tablename) t := v.Type() if t.Kind() == reflect.Struct { - return engine.tbNameForMap(v) + return tableName{name: tbNameForMap(mapper, v)}, nil } - return engine.Quote(fmt.Sprintf("%v", tablename)) + return tableName{name: fmt.Sprintf("%v", tablename)}, nil } - return "" + return tableName{}, ErrTableNotFound +} + +func (t tableName) withSchema() string { + if t.schema == "" { + return t.withNoSchema() + } + + if t.alias != "" { + if t.aliasSplitter != "" { + return fmt.Sprintf("%s.%s %s %s", t.schema, t.name, t.aliasSplitter, t.alias) + } + return fmt.Sprintf("%s.%s %s", t.schema, t.name, t.alias) + } + return fmt.Sprintf("%s.%s", t.schema, t.name) +} + +func (t tableName) withNoSchema() string { + if t.alias != "" { + if t.aliasSplitter != "" { + return fmt.Sprintf("%s %s %s", t.name, t.aliasSplitter, t.alias) + } + return fmt.Sprintf("%s %s", t.name, t.alias) + } + return t.name } diff --git a/session_find.go b/session_find.go index e16ae54c..44acb931 100644 --- a/session_find.go +++ b/session_find.go @@ -115,8 +115,8 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) var colName = session.engine.Quote(col.Name) if addedTableName { var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias + if len(session.statement.tableAlias) > 0 { + nm = session.statement.tableAlias } colName = session.engine.Quote(nm) + "." + colName } diff --git a/session_find_test.go b/session_find_test.go index f805f06e..1211ab37 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestJoinLimit(t *testing.T) { @@ -801,3 +801,25 @@ func TestFindJoin(t *testing.T) { Where("scene_item.type=?", 3).Or("device_user_privrels.user_id=?", 339).Find(&scenes) assert.NoError(t, err) } + +func TestJoinReverseWord(t *testing.T) { + type JoinReverseWord struct { + Id int64 + Name string + } + + type JoinReverseWord2 struct { + Id int64 + UserId int64 `xorm:"index"` + Age int + } + + assert.NoError(t, prepareEngine()) + err := testEngine.Table("order").Sync2(new(JoinReverseWord)) + assert.NoError(t, err) + assertSync(t, new(JoinReverseWord2)) + + var j2 []JoinReverseWord2 + err = testEngine.Join("INNER", "order", "`join_reverse_word2`.user_id=`order`.id").Find(&j2) + assert.NoError(t, err) +} diff --git a/session_get_test.go b/session_get_test.go index fcef992e..e3f40515 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -492,6 +492,8 @@ func TestGetCustomTableInterface(t *testing.T) { assert.NoError(t, err) assert.True(t, exist) + assert.EqualValues(t, getCustomTableName, testEngine.TableInfo(new(MyGetCustomTableImpletation)).Name) + _, err = testEngine.Insert(&MyGetCustomTableImpletation{ Name: "xlw", }) diff --git a/session_insert.go b/session_insert.go index 5f8f7e1e..1d0af00a 100644 --- a/session_insert.go +++ b/session_insert.go @@ -324,7 +324,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err := session.statement.setRefBean(bean); err != nil { return 0, err } - if len(session.statement.TableName()) <= 0 { + var tableName = session.statement.TableName() + fmt.Println("------", tableName) + if len(tableName) <= 0 { return 0, ErrTableNotFound } @@ -351,7 +353,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { colPlaces = colPlaces[0 : len(colPlaces)-2] } - var tableName = session.statement.TableName() var output string if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 { output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) diff --git a/session_schema.go b/session_schema.go index 5e576c29..5264bbd3 100644 --- a/session_schema.go +++ b/session_schema.go @@ -245,9 +245,13 @@ func (session *Session) Sync2(beans ...interface{}) error { if err != nil { return err } - var tbName string - if len(session.statement.AltTableName) > 0 { - tbName = session.statement.AltTableName + + var ( + tbName string + altTableName = session.statement.altTableName + ) + if len(altTableName) > 0 { + tbName = altTableName } else { tbName = engine.TableName(bean) } @@ -298,7 +302,7 @@ func (session *Session) Sync2(beans ...interface{}) error { // column is not exist on table if oriCol == nil { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.altTableName = altTableName if err = session.addColumn(col.Name); err != nil { return err } @@ -406,11 +410,11 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == core.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.altTableName = altTableName err = session.addUnique(tbNameWithSchema, name) } else if index.Type == core.IndexType { session.statement.RefTable = table - session.statement.tableName = tbNameWithSchema + session.statement.altTableName = altTableName err = session.addIndex(tbNameWithSchema, name) } if err != nil { diff --git a/session_update.go b/session_update.go index 22d516e7..967f4b15 100644 --- a/session_update.go +++ b/session_update.go @@ -389,13 +389,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var tableAlias = session.engine.Quote(tableName) var fromSQL string - if session.statement.TableAlias != "" { + if session.statement.tableAlias != "" { switch session.engine.dialect.DBType() { case core.MSSQL: - fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) - tableAlias = session.statement.TableAlias + fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.tableAlias) + tableAlias = session.statement.tableAlias default: - tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias) + tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.tableAlias) } } diff --git a/statement.go b/statement.go index 67e35213..18c91224 100644 --- a/statement.go +++ b/statement.go @@ -31,8 +31,7 @@ type Statement struct { selectStr string useAllCols bool OmitStr string - AltTableName string - tableName string + altTableName string RawSQL string RawParams []interface{} UseCascade bool @@ -44,7 +43,7 @@ type Statement struct { noAutoCondition bool IsDistinct bool IsForUpdate bool - TableAlias string + tableAlias string allUseBool bool checkVersion bool unscoped bool @@ -76,8 +75,7 @@ func (statement *Statement) Init() { statement.OmitStr = "" statement.columnMap = columnMap{} statement.omitColumnMap = columnMap{} - statement.AltTableName = "" - statement.tableName = "" + statement.altTableName = "" statement.idParam = nil statement.RawSQL = "" statement.RawParams = make([]interface{}, 0) @@ -86,7 +84,7 @@ func (statement *Statement) Init() { statement.noAutoCondition = false statement.IsDistinct = false statement.IsForUpdate = false - statement.TableAlias = "" + statement.tableAlias = "" statement.selectStr = "" statement.allUseBool = false statement.useAllCols = false @@ -114,7 +112,7 @@ func (statement *Statement) NoAutoCondition(no ...bool) *Statement { // Alias set the table alias func (statement *Statement) Alias(alias string) *Statement { - statement.TableAlias = alias + statement.tableAlias = alias return statement } @@ -209,22 +207,31 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement func (statement *Statement) setRefValue(v reflect.Value) error { var err error - statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v)) - if err != nil { - return err - } - statement.tableName = statement.Engine.TableName(v, true) - return nil + statement.RefTable, err = statement.Engine.autoMapType(v) + return err } func (statement *Statement) setRefBean(bean interface{}) error { - var err error - statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) - if err != nil { - return err + return statement.setRefValue(reflect.ValueOf(bean)) +} + +func (statement *Statement) getTableName() tableName { + var name = statement.altTableName + if name == "" && statement.RefTable != nil { + name = statement.RefTable.Name + } + + var aliasSplitter = "AS" + if statement.Engine.dialect.DBType() == core.MSSQL { + aliasSplitter = "" + } + + return tableName{ + name: name, + alias: statement.tableAlias, + aliasSplitter: aliasSplitter, + schema: statement.Engine.Dialect().URI().Schema, } - statement.tableName = statement.Engine.TableName(bean, true) - return nil } // Auto generating update columnes and values according a struct @@ -492,28 +499,27 @@ func (statement *Statement) buildUpdates(bean interface{}, return colNames, args } -func (statement *Statement) needTableName() bool { +func (statement *Statement) colsNeedTableName() bool { return len(statement.JoinStr) > 0 } -func (statement *Statement) colName(col *core.Column, tableName string) string { - if statement.needTableName() { - var nm = tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias - } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) +func (statement *Statement) writeColName(buf *strings.Builder, colName string) { + quotePair := statement.Engine.Dialect().Quote("") + if statement.colsNeedTableName() { + tbname := statement.getTableName() + quoteTo(buf, quotePair, tbname.withSchema()) + buf.WriteByte('.') } - return statement.Engine.Quote(col.Name) + quoteTo(buf, quotePair, colName) } -// TableName return current tableName -func (statement *Statement) TableName() string { - if statement.AltTableName != "" { - return statement.AltTableName +// fullColName return a column name with schema/table name and quotes +func (statement *Statement) fullColName(colName string) string { + if statement.colsNeedTableName() { + tbname := statement.getTableName() + return tbname.withSchema() + "." + statement.Engine.Quote(colName) } - - return statement.tableName + return statement.Engine.Quote(colName) } // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?" @@ -716,18 +722,24 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { v := rValue(tableNameOrBean) t := v.Type() if t.Kind() == reflect.Struct { - var err error - statement.RefTable, err = statement.Engine.autoMapType(v) - if err != nil { - statement.Engine.logger.Error(err) - return statement - } + statement.setRefValue(v) } - statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true) + statement.altTableName = statement.Engine.TableName(tableNameOrBean, false) return statement } +// TableName return table name +func (statement *Statement) TableName() string { + if statement.altTableName != "" { + return statement.altTableName + } + if statement.RefTable != nil { + return statement.RefTable.Name + } + return "" +} + // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { var buf strings.Builder @@ -764,7 +776,8 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: tbName := statement.Engine.TableName(tablename, true) - fmt.Fprintf(&buf, "%s ON %v", tbName, condition) + fmt.Println("------", tbName) + fmt.Fprintf(&buf, "%s ON %v", statement.Engine.Quote(tbName), condition) } statement.JoinStr = buf.String() @@ -815,17 +828,7 @@ func (statement *Statement) genColumnStr() string { buf.WriteString(", ") } - if statement.JoinStr != "" { - if statement.TableAlias != "" { - buf.WriteString(statement.TableAlias) - } else { - buf.WriteString(statement.TableName()) - } - - buf.WriteString(".") - } - - statement.Engine.QuoteTo(&buf, col.Name) + statement.writeColName(&buf, col.Name) } return buf.String() @@ -902,7 +905,7 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) + statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.tableAlias, addedTableName) } func (statement *Statement) mergeConds(bean interface{}) error { @@ -1060,11 +1063,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n fromStr += quote(statement.TableName()) } - if statement.TableAlias != "" { + if statement.tableAlias != "" { if dialect.DBType() == core.ORACLE { - fromStr += " " + quote(statement.TableAlias) + fromStr += " " + quote(statement.tableAlias) } else { - fromStr += " AS " + quote(statement.TableAlias) + fromStr += " AS " + quote(statement.tableAlias) } } if statement.JoinStr != "" { @@ -1090,13 +1093,8 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n } else { column = statement.RefTable.PKColumns()[0].Name } - if statement.needTableName() { - if len(statement.TableAlias) > 0 { - column = statement.TableAlias + "." + column - } else { - column = statement.TableName() + "." + column - } - } + + column = statement.fullColName(column) var orderStr string if needOrderBy && len(statement.OrderStr) > 0 { @@ -1171,7 +1169,7 @@ func (statement *Statement) processIDParam() error { } for i, col := range statement.RefTable.PKColumns() { - var colName = statement.colName(col, statement.TableName()) + var colName = statement.fullColName(col.Name) statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) } return nil diff --git a/tag_extends_test.go b/tag_extends_test.go index 5a8031f0..b23b8181 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) type tempUser struct {