From 3442ce81b6af606506a089db4bd061c6bff01ead Mon Sep 17 00:00:00 2001 From: Victor Gaydov Date: Thu, 28 Apr 2016 16:26:20 +0300 Subject: [PATCH 1/3] Always generate column names, don't use * even if join --- engine.go | 18 +++++- session.go | 70 +++++++------------- statement.go | 176 ++++++++++++++++++++++++++++++++------------------- 3 files changed, 151 insertions(+), 113 deletions(-) diff --git a/engine.go b/engine.go index ee7cb7bc..b7c39492 100644 --- a/engine.go +++ b/engine.go @@ -929,8 +929,16 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { fieldType := fieldValue.Type() if ormTagStr != "" { - col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)} + col = &core.Column{ + FieldName: t.Field(i).Name, + TableName: table.Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: core.TWOSIDES, + Indexes: make(map[string]bool), + } + tags := splitTag(ormTagStr) if len(tags) > 0 { @@ -952,6 +960,11 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { case reflect.Struct: parentTable := engine.mapType(fieldValue) for _, col := range parentTable.Columns() { + if t.Field(i).Anonymous { + col.TableName = parentTable.Name + } else { + col.TableName = engine.TableMapper.Obj2Table(t.Field(i).Name) + } col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName) table.AddColumn(col) } @@ -1133,6 +1146,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) + col.TableName = table.Name } if col.IsAutoIncrement { col.Nullable = false diff --git a/session.go b/session.go index 27ebf2da..f29d8201 100644 --- a/session.go +++ b/session.go @@ -1044,8 +1044,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sqlStr string var args []interface{} + session.Statement.OutTable = session.Engine.TableInfo(bean) + if session.Statement.RefTable == nil { - session.Statement.RefTable = session.Engine.TableInfo(bean) + session.Statement.RefTable = session.Statement.OutTable } if session.Statement.RawSQL == "" { @@ -1139,42 +1141,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() var table *core.Table - if session.Statement.RefTable == nil { - if sliceElementType.Kind() == reflect.Ptr { - if sliceElementType.Elem().Kind() == reflect.Struct { - pv := reflect.New(sliceElementType.Elem()) - table = session.Engine.autoMapType(pv.Elem()) - } else { - return errors.New("slice type") - } - } else if sliceElementType.Kind() == reflect.Struct { - pv := reflect.New(sliceElementType) + + if sliceElementType.Kind() == reflect.Ptr { + if sliceElementType.Elem().Kind() == reflect.Struct { + pv := reflect.New(sliceElementType.Elem()) table = session.Engine.autoMapType(pv.Elem()) } else { return errors.New("slice type") } - session.Statement.RefTable = table + } else if sliceElementType.Kind() == reflect.Struct { + pv := reflect.New(sliceElementType) + table = session.Engine.autoMapType(pv.Elem()) } else { - table = session.Statement.RefTable + return errors.New("slice type") + } + + session.Statement.OutTable = table + + if session.Statement.RefTable == nil { + session.Statement.RefTable = table } - var addedTableName = (len(session.Statement.JoinStr) > 0) if !session.Statement.noAutoCondition && len(condiBean) > 0 { - colNames, args := session.Statement.buildConditions(table, condiBean[0], true, true, false, true, addedTableName) + colNames, args := session.Statement.buildConditions( + table, condiBean[0], true, true, false, true, session.Statement.needTableName()) + session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } else { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://github.com/go-xorm/xorm/issues/179 - if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled - var colName = session.Engine.Quote(col.Name) - if addedTableName { - var nm = session.Statement.TableName() - if len(session.Statement.TableAlias) > 0 { - nm = session.Statement.TableAlias - } - colName = session.Engine.Quote(nm) + "." + colName - } + if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { + // tag "deleted" is enabled + var colName = session.Statement.colName(col, session.Statement.TableName()) session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')", colName, colName) } @@ -1183,28 +1182,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - var columnStr = session.Statement.ColumnStr - if len(session.Statement.selectStr) > 0 { - columnStr = session.Statement.selectStr - } else { - if session.Statement.JoinStr == "" { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = session.Statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = "*" - } - } - } - } + columnStr := session.Statement.genColumnStr() session.Statement.Params = append(session.Statement.joinArgs, append(session.Statement.Params, session.Statement.BeanArgs...)...) diff --git a/statement.go b/statement.go index 3dd5c60c..214906fa 100644 --- a/statement.go +++ b/statement.go @@ -40,6 +40,7 @@ type exprParam struct { // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table + OutTable *core.Table Engine *Engine Start int LimitN int @@ -54,6 +55,7 @@ type Statement struct { ColumnStr string selectStr string columnMap map[string]bool + tableMap map[string]bool useAllCols bool OmitStr string ConditionStr string @@ -85,6 +87,7 @@ type Statement struct { // Init reset all the statment's fields func (statement *Statement) Init() { statement.RefTable = nil + statement.OutTable = nil statement.Start = 0 statement.LimitN = 0 statement.WhereStr = "" @@ -98,6 +101,7 @@ func (statement *Statement) Init() { statement.ColumnStr = "" statement.OmitStr = "" statement.columnMap = make(map[string]bool) + statement.tableMap = make(map[string]bool) statement.ConditionStr = "" statement.AltTableName = "" statement.IdParam = nil @@ -141,7 +145,14 @@ func (statement *Statement) Sql(querystring string, args ...interface{}) *Statem // Alias set the table alias func (statement *Statement) Alias(alias string) *Statement { + if statement.TableName() != "" { + statement.tableMapDelete(statement.TableName()) + } + if statement.TableAlias != "" { + statement.tableMapDelete(statement.TableAlias) + } statement.TableAlias = alias + statement.tableMapAdd(alias) return statement } @@ -190,6 +201,9 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme // Table tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { + if statement.TableName() != "" { + statement.tableMapDelete(statement.TableName()) + } v := rValue(tableNameOrBean) t := v.Type() if t.Kind() == reflect.String { @@ -197,6 +211,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { } else if t.Kind() == reflect.Struct { statement.RefTable = statement.Engine.autoMapType(v) } + statement.tableMapAdd(statement.TableName()) return statement } @@ -440,21 +455,46 @@ func (statement *Statement) needTableName() bool { } func (statement *Statement) colName(col *core.Column, tableName string) string { - if statement.needTableName() { - var nm = tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias + return buildColName(statement.Engine, col, tableName, statement.TableAlias, + statement.outTableName(), statement.needTableName(), statement.tableMap) +} + +func buildColName(engine *Engine, col *core.Column, + mainTableName, mainTableAlias, outTableName string, needTableName bool, + knownTables map[string]bool) string { + var colTable string + + if needTableName { + var mainTable string + + if len(mainTableAlias) > 0 { + mainTable = mainTableAlias + } else { + mainTable = mainTableName + } + + if isKnownTable(engine, knownTables, mainTable, col.TableName) { + colTable = col.TableName + } else if isKnownTable(engine, knownTables, mainTable, outTableName) { + colTable = outTableName + } else { + colTable = "" } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) } - return statement.Engine.Quote(col.Name) + + if colTable != "" { + return engine.Quote(colTable) + "." + engine.Quote(col.Name) + } else { + return engine.Quote(col.Name) + } } // Auto generating conditions according a struct func buildConditions(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) ([]string, []interface{}) { + mustColumnMap map[string]bool, tableName, aliasName, outTableName string, + addedTableName bool, knownTables map[string]bool) ([]string, []interface{}) { var colNames []string var args = make([]interface{}, 0) for _, col := range table.Columns() { @@ -475,16 +515,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, continue } - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } + colName := buildColName( + engine, col, tableName, aliasName, outTableName, addedTableName, knownTables) fieldValuePtr, err := col.ValueOf(bean) if err != nil { @@ -688,6 +720,13 @@ func (statement *Statement) TableName() string { return "" } +func (statement *Statement) outTableName() string { + if statement.OutTable != nil { + return statement.OutTable.Name + } + return "" +} + // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" func (statement *Statement) Id(id interface{}) *Statement { idValue := reflect.ValueOf(id) @@ -979,12 +1018,15 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } + var refName string switch tablename.(type) { case []string: t := tablename.([]string) if len(t) > 1 { + refName = t[1] fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) } else if len(t) == 1 { + refName = t[0] fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) } case []interface{}: @@ -1003,18 +1045,22 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } } if l > 1 { + refName = fmt.Sprintf("%v", t[1]) fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), - statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) + statement.Engine.Quote(refName)) } else if l == 1 { + refName = table fmt.Fprintf(&buf, statement.Engine.Quote(table)) } default: - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) + refName = fmt.Sprintf("%v", tablename) + fmt.Fprintf(&buf, statement.Engine.Quote(refName)) } fmt.Fprintf(&buf, " ON %v", condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) + statement.tableMap[statement.Engine.Quote(strings.ToLower(refName))] = true return statement } @@ -1037,7 +1083,23 @@ func (statement *Statement) Unscoped() *Statement { } func (statement *Statement) genColumnStr() string { + if len(statement.selectStr) > 0 { + return statement.selectStr + } + + if len(statement.ColumnStr) > 0 { + return statement.ColumnStr + } + + if len(statement.GroupByStr) > 0 { + return statement.Engine.Quote( + strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + } + table := statement.RefTable + if statement.OutTable != nil { + table = statement.OutTable + } colNames := make([]string, 0) for _, col := range table.Columns() { if statement.OmitStr != "" { @@ -1049,26 +1111,12 @@ func (statement *Statement) genColumnStr() string { continue } - if statement.JoinStr != "" { - var name string - if statement.TableAlias != "" { - name = statement.Engine.Quote(statement.TableAlias) - } else { - name = statement.Engine.Quote(statement.TableName()) - } - name += "." + statement.Engine.Quote(col.Name) - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - colNames = append(colNames, "id() AS "+name) - } else { - colNames = append(colNames, name) - } + name := statement.colName(col, statement.TableName()) + + if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { + colNames = append(colNames, "id() AS "+name) } else { - name := statement.Engine.Quote(col.Name) - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - colNames = append(colNames, "id() AS "+name) - } else { - colNames = append(colNames, name) - } + colNames = append(colNames, name) } } return strings.Join(colNames, ", ") @@ -1135,38 +1183,15 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) table = statement.RefTable } - var addedTableName = (len(statement.JoinStr) > 0) - if !statement.noAutoCondition { - colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) + colNames, args := statement.buildConditions( + table, bean, true, true, false, true, statement.needTableName()) statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") statement.BeanArgs = args } - var columnStr string = statement.ColumnStr - if len(statement.selectStr) > 0 { - columnStr = statement.selectStr - } else { - // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = "*" - } - } - } - } + columnStr := statement.genColumnStr() statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)" return statement.genSelectSQL(columnStr), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) @@ -1195,7 +1220,7 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) { return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) + statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, statement.outTableName(), addedTableName, statement.tableMap) } func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { @@ -1347,3 +1372,24 @@ func (statement *Statement) processIdParam() { } } } + +func (statement *Statement) tableMapAdd(table string) { + tableName := statement.Engine.Quote(strings.ToLower(table)) + statement.tableMap[tableName] = true +} + +func (statement *Statement) tableMapDelete(table string) { + tableName := statement.Engine.Quote(strings.ToLower(table)) + delete(statement.tableMap, tableName) +} + +func isKnownTable(engine *Engine, tableMap map[string]bool, mainTable, table string) bool { + if len(table) > 0 { + cm := engine.Quote(strings.ToLower(mainTable)) + ct := engine.Quote(strings.ToLower(table)) + if ct == cm || tableMap[ct] { + return true + } + } + return false +} From ee6e17b4e39e69ea4c8a8f54f3dc55a0e13fc5c4 Mon Sep 17 00:00:00 2001 From: Victor Gaydov Date: Sun, 8 May 2016 15:11:29 +0300 Subject: [PATCH 2/3] Cleanup code and remove unnecessary wrappers in statement.go --- session.go | 6 +-- statement.go | 108 +++++++++++++++++++++++---------------------------- 2 files changed, 51 insertions(+), 63 deletions(-) diff --git a/session.go b/session.go index f29d8201..5e16c5b4 100644 --- a/session.go +++ b/session.go @@ -1173,9 +1173,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) // See https://github.com/go-xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled - var colName = session.Statement.colName(col, session.Statement.TableName()) - session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')", - colName, colName) + var colName = session.Statement.colName(col) + session.Statement.ConditionStr = fmt.Sprintf( + "(%v IS NULL OR %v = '0001-01-01 00:00:00')", colName, colName) } } diff --git a/statement.go b/statement.go index 214906fa..cd8a6a5e 100644 --- a/statement.go +++ b/statement.go @@ -454,47 +454,62 @@ func (statement *Statement) needTableName() bool { return len(statement.JoinStr) > 0 } -func (statement *Statement) colName(col *core.Column, tableName string) string { - return buildColName(statement.Engine, col, tableName, statement.TableAlias, - statement.outTableName(), statement.needTableName(), statement.tableMap) +func (statement *Statement) tableMapAdd(table string) { + tableName := statement.Engine.Quote(strings.ToLower(table)) + statement.tableMap[tableName] = true } -func buildColName(engine *Engine, col *core.Column, - mainTableName, mainTableAlias, outTableName string, needTableName bool, - knownTables map[string]bool) string { - var colTable string +func (statement *Statement) tableMapDelete(table string) { + tableName := statement.Engine.Quote(strings.ToLower(table)) + delete(statement.tableMap, tableName) +} - if needTableName { +func (statement *Statement) isKnownTable(table string) bool { + if len(table) > 0 { var mainTable string - if len(mainTableAlias) > 0 { - mainTable = mainTableAlias + if len(statement.TableAlias) > 0 { + mainTable = statement.TableAlias } else { - mainTable = mainTableName + mainTable = statement.TableName() } - if isKnownTable(engine, knownTables, mainTable, col.TableName) { + cm := statement.Engine.Quote(strings.ToLower(mainTable)) + ct := statement.Engine.Quote(strings.ToLower(table)) + + if ct == cm || statement.tableMap[ct] { + return true + } + } + return false +} + +func (statement *Statement) colName(col *core.Column) string { + var colTable string + + if statement.needTableName() { + if statement.isKnownTable(col.TableName) { colTable = col.TableName - } else if isKnownTable(engine, knownTables, mainTable, outTableName) { - colTable = outTableName + } else if statement.isKnownTable(statement.outTableName()) { + colTable = statement.outTableName() } else { colTable = "" } } if colTable != "" { - return engine.Quote(colTable) + "." + engine.Quote(col.Name) + return statement.Engine.Quote(colTable) + "." + statement.Engine.Quote(col.Name) } else { - return engine.Quote(col.Name) + return statement.Engine.Quote(col.Name) } } // Auto generating conditions according a struct -func buildConditions(engine *Engine, table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName, outTableName string, - addedTableName bool, knownTables map[string]bool) ([]string, []interface{}) { +func (statement *Statement) buildConditions( + table *core.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, + addedTableName bool) ([]string, []interface{}) { + engine := statement.Engine var colNames []string var args = make([]interface{}, 0) for _, col := range table.Columns() { @@ -515,8 +530,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, continue } - colName := buildColName( - engine, col, tableName, aliasName, outTableName, addedTableName, knownTables) + colName := statement.colName(col) fieldValuePtr, err := col.ValueOf(bean) if err != nil { @@ -524,9 +538,9 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, continue } - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - colNames = append(colNames, fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')", - colName, colName)) + if col.IsDeleted && !statement.unscoped { // tag "deleted" is enabled + colNames = append(colNames, fmt.Sprintf( + "(%v IS NULL OR %v = '0001-01-01 00:00:00')", colName, colName)) } fieldValue := *fieldValuePtr @@ -535,8 +549,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, } fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { + requiredField := statement.useAllCols + if b, ok := statement.mustColumnMap[strings.ToLower(col.Name)]; ok { if b { requiredField = true } else { @@ -564,7 +578,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, var val interface{} switch fieldType.Kind() { case reflect.Bool: - if allUseBool || requiredField { + if statement.allUseBool || requiredField { val = fieldValue.Interface() } else { // if a bool in a struct, it will not be as a condition because it default is false, @@ -1111,7 +1125,7 @@ func (statement *Statement) genColumnStr() string { continue } - name := statement.colName(col, statement.TableName()) + name := statement.colName(col) if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { colNames = append(colNames, "id() AS "+name) @@ -1218,21 +1232,16 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in return sql, []interface{}{} }*/ -func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) { - return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, statement.outTableName(), addedTableName, statement.tableMap) -} - func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { table := statement.Engine.TableInfo(bean) statement.RefTable = table - var addedTableName = (len(statement.JoinStr) > 0) - if !statement.noAutoCondition { - colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) + colNames, args := statement.buildConditions( + table, bean, true, true, false, true, statement.needTableName()) statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ") + statement.BeanArgs = args } @@ -1356,7 +1365,7 @@ func (statement *Statement) processIdParam() { if statement.IdParam != nil { if statement.Engine.dialect.DBType() != "ql" { for i, col := range statement.RefTable.PKColumns() { - var colName = statement.colName(col, statement.TableName()) + var colName = statement.colName(col) if i < len(*(statement.IdParam)) { statement.And(fmt.Sprintf("%v %s ?", colName, statement.Engine.dialect.EqStr()), (*(statement.IdParam))[i]) @@ -1372,24 +1381,3 @@ func (statement *Statement) processIdParam() { } } } - -func (statement *Statement) tableMapAdd(table string) { - tableName := statement.Engine.Quote(strings.ToLower(table)) - statement.tableMap[tableName] = true -} - -func (statement *Statement) tableMapDelete(table string) { - tableName := statement.Engine.Quote(strings.ToLower(table)) - delete(statement.tableMap, tableName) -} - -func isKnownTable(engine *Engine, tableMap map[string]bool, mainTable, table string) bool { - if len(table) > 0 { - cm := engine.Quote(strings.ToLower(mainTable)) - ct := engine.Quote(strings.ToLower(table)) - if ct == cm || tableMap[ct] { - return true - } - } - return false -} From 7aed2a253b10f122b3bde3df08cb5085de179924 Mon Sep 17 00:00:00 2001 From: Victor Gaydov Date: Thu, 12 May 2016 23:29:12 +0300 Subject: [PATCH 3/3] Use alias name for field qualifiers when available --- statement.go | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/statement.go b/statement.go index cd8a6a5e..2dac5113 100644 --- a/statement.go +++ b/statement.go @@ -55,7 +55,7 @@ type Statement struct { ColumnStr string selectStr string columnMap map[string]bool - tableMap map[string]bool + tableMap map[string]string useAllCols bool OmitStr string ConditionStr string @@ -101,7 +101,7 @@ func (statement *Statement) Init() { statement.ColumnStr = "" statement.OmitStr = "" statement.columnMap = make(map[string]bool) - statement.tableMap = make(map[string]bool) + statement.tableMap = make(map[string]string) statement.ConditionStr = "" statement.AltTableName = "" statement.IdParam = nil @@ -201,7 +201,7 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme // Table tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - if statement.TableName() != "" { + if statement.TableAlias == "" && statement.TableName() != "" { statement.tableMapDelete(statement.TableName()) } v := rValue(tableNameOrBean) @@ -211,7 +211,9 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { } else if t.Kind() == reflect.Struct { statement.RefTable = statement.Engine.autoMapType(v) } - statement.tableMapAdd(statement.TableName()) + if statement.TableAlias == "" { + statement.tableMapAdd(statement.TableName()) + } return statement } @@ -456,7 +458,7 @@ func (statement *Statement) needTableName() bool { func (statement *Statement) tableMapAdd(table string) { tableName := statement.Engine.Quote(strings.ToLower(table)) - statement.tableMap[tableName] = true + statement.tableMap[tableName] = table } func (statement *Statement) tableMapDelete(table string) { @@ -464,7 +466,7 @@ func (statement *Statement) tableMapDelete(table string) { delete(statement.tableMap, tableName) } -func (statement *Statement) isKnownTable(table string) bool { +func (statement *Statement) isKnownTable(table string) (string, bool) { if len(table) > 0 { var mainTable string @@ -477,21 +479,25 @@ func (statement *Statement) isKnownTable(table string) bool { cm := statement.Engine.Quote(strings.ToLower(mainTable)) ct := statement.Engine.Quote(strings.ToLower(table)) - if ct == cm || statement.tableMap[ct] { - return true + if name, ok := statement.tableMap[ct]; ok { + return name, true + } + + if ct == cm { + return mainTable, true } } - return false + return "", false } func (statement *Statement) colName(col *core.Column) string { var colTable string if statement.needTableName() { - if statement.isKnownTable(col.TableName) { - colTable = col.TableName - } else if statement.isKnownTable(statement.outTableName()) { - colTable = statement.outTableName() + if name, ok := statement.isKnownTable(col.TableName); ok { + colTable = name + } else if name, ok := statement.isKnownTable(statement.outTableName()); ok { + colTable = name } else { colTable = "" } @@ -1074,7 +1080,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, " ON %v", condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) - statement.tableMap[statement.Engine.Quote(strings.ToLower(refName))] = true + statement.tableMapAdd(refName) return statement }