From 3442ce81b6af606506a089db4bd061c6bff01ead Mon Sep 17 00:00:00 2001 From: Victor Gaydov Date: Thu, 28 Apr 2016 16:26:20 +0300 Subject: [PATCH] Always generate column names, don't use * even if join --- engine.go | 18 +++++- session.go | 70 +++++++------------- statement.go | 176 ++++++++++++++++++++++++++++++++------------------- 3 files changed, 151 insertions(+), 113 deletions(-) diff --git a/engine.go b/engine.go index ee7cb7bc..b7c39492 100644 --- a/engine.go +++ b/engine.go @@ -929,8 +929,16 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { fieldType := fieldValue.Type() if ormTagStr != "" { - col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)} + col = &core.Column{ + FieldName: t.Field(i).Name, + TableName: table.Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: core.TWOSIDES, + Indexes: make(map[string]bool), + } + tags := splitTag(ormTagStr) if len(tags) > 0 { @@ -952,6 +960,11 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { case reflect.Struct: parentTable := engine.mapType(fieldValue) for _, col := range parentTable.Columns() { + if t.Field(i).Anonymous { + col.TableName = parentTable.Name + } else { + col.TableName = engine.TableMapper.Obj2Table(t.Field(i).Name) + } col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName) table.AddColumn(col) } @@ -1133,6 +1146,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) + col.TableName = table.Name } if col.IsAutoIncrement { col.Nullable = false diff --git a/session.go b/session.go index 27ebf2da..f29d8201 100644 --- a/session.go +++ b/session.go @@ -1044,8 +1044,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sqlStr string var args []interface{} + session.Statement.OutTable = session.Engine.TableInfo(bean) + if session.Statement.RefTable == nil { - session.Statement.RefTable = session.Engine.TableInfo(bean) + session.Statement.RefTable = session.Statement.OutTable } if session.Statement.RawSQL == "" { @@ -1139,42 +1141,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() var table *core.Table - if session.Statement.RefTable == nil { - if sliceElementType.Kind() == reflect.Ptr { - if sliceElementType.Elem().Kind() == reflect.Struct { - pv := reflect.New(sliceElementType.Elem()) - table = session.Engine.autoMapType(pv.Elem()) - } else { - return errors.New("slice type") - } - } else if sliceElementType.Kind() == reflect.Struct { - pv := reflect.New(sliceElementType) + + if sliceElementType.Kind() == reflect.Ptr { + if sliceElementType.Elem().Kind() == reflect.Struct { + pv := reflect.New(sliceElementType.Elem()) table = session.Engine.autoMapType(pv.Elem()) } else { return errors.New("slice type") } - session.Statement.RefTable = table + } else if sliceElementType.Kind() == reflect.Struct { + pv := reflect.New(sliceElementType) + table = session.Engine.autoMapType(pv.Elem()) } else { - table = session.Statement.RefTable + return errors.New("slice type") + } + + session.Statement.OutTable = table + + if session.Statement.RefTable == nil { + session.Statement.RefTable = table } - var addedTableName = (len(session.Statement.JoinStr) > 0) if !session.Statement.noAutoCondition && len(condiBean) > 0 { - colNames, args := session.Statement.buildConditions(table, condiBean[0], true, true, false, true, addedTableName) + colNames, args := session.Statement.buildConditions( + table, condiBean[0], true, true, false, true, session.Statement.needTableName()) + session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } else { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://github.com/go-xorm/xorm/issues/179 - if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled - var colName = session.Engine.Quote(col.Name) - if addedTableName { - var nm = session.Statement.TableName() - if len(session.Statement.TableAlias) > 0 { - nm = session.Statement.TableAlias - } - colName = session.Engine.Quote(nm) + "." + colName - } + if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { + // tag "deleted" is enabled + var colName = session.Statement.colName(col, session.Statement.TableName()) session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')", colName, colName) } @@ -1183,28 +1182,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - var columnStr = session.Statement.ColumnStr - if len(session.Statement.selectStr) > 0 { - columnStr = session.Statement.selectStr - } else { - if session.Statement.JoinStr == "" { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = session.Statement.genColumnStr() - } - } - } else { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = "*" - } - } - } - } + columnStr := session.Statement.genColumnStr() session.Statement.Params = append(session.Statement.joinArgs, append(session.Statement.Params, session.Statement.BeanArgs...)...) diff --git a/statement.go b/statement.go index 3dd5c60c..214906fa 100644 --- a/statement.go +++ b/statement.go @@ -40,6 +40,7 @@ type exprParam struct { // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table + OutTable *core.Table Engine *Engine Start int LimitN int @@ -54,6 +55,7 @@ type Statement struct { ColumnStr string selectStr string columnMap map[string]bool + tableMap map[string]bool useAllCols bool OmitStr string ConditionStr string @@ -85,6 +87,7 @@ type Statement struct { // Init reset all the statment's fields func (statement *Statement) Init() { statement.RefTable = nil + statement.OutTable = nil statement.Start = 0 statement.LimitN = 0 statement.WhereStr = "" @@ -98,6 +101,7 @@ func (statement *Statement) Init() { statement.ColumnStr = "" statement.OmitStr = "" statement.columnMap = make(map[string]bool) + statement.tableMap = make(map[string]bool) statement.ConditionStr = "" statement.AltTableName = "" statement.IdParam = nil @@ -141,7 +145,14 @@ func (statement *Statement) Sql(querystring string, args ...interface{}) *Statem // Alias set the table alias func (statement *Statement) Alias(alias string) *Statement { + if statement.TableName() != "" { + statement.tableMapDelete(statement.TableName()) + } + if statement.TableAlias != "" { + statement.tableMapDelete(statement.TableAlias) + } statement.TableAlias = alias + statement.tableMapAdd(alias) return statement } @@ -190,6 +201,9 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme // Table tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { + if statement.TableName() != "" { + statement.tableMapDelete(statement.TableName()) + } v := rValue(tableNameOrBean) t := v.Type() if t.Kind() == reflect.String { @@ -197,6 +211,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { } else if t.Kind() == reflect.Struct { statement.RefTable = statement.Engine.autoMapType(v) } + statement.tableMapAdd(statement.TableName()) return statement } @@ -440,21 +455,46 @@ func (statement *Statement) needTableName() bool { } func (statement *Statement) colName(col *core.Column, tableName string) string { - if statement.needTableName() { - var nm = tableName - if len(statement.TableAlias) > 0 { - nm = statement.TableAlias + return buildColName(statement.Engine, col, tableName, statement.TableAlias, + statement.outTableName(), statement.needTableName(), statement.tableMap) +} + +func buildColName(engine *Engine, col *core.Column, + mainTableName, mainTableAlias, outTableName string, needTableName bool, + knownTables map[string]bool) string { + var colTable string + + if needTableName { + var mainTable string + + if len(mainTableAlias) > 0 { + mainTable = mainTableAlias + } else { + mainTable = mainTableName + } + + if isKnownTable(engine, knownTables, mainTable, col.TableName) { + colTable = col.TableName + } else if isKnownTable(engine, knownTables, mainTable, outTableName) { + colTable = outTableName + } else { + colTable = "" } - return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) } - return statement.Engine.Quote(col.Name) + + if colTable != "" { + return engine.Quote(colTable) + "." + engine.Quote(col.Name) + } else { + return engine.Quote(col.Name) + } } // Auto generating conditions according a struct func buildConditions(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) ([]string, []interface{}) { + mustColumnMap map[string]bool, tableName, aliasName, outTableName string, + addedTableName bool, knownTables map[string]bool) ([]string, []interface{}) { var colNames []string var args = make([]interface{}, 0) for _, col := range table.Columns() { @@ -475,16 +515,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, continue } - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } + colName := buildColName( + engine, col, tableName, aliasName, outTableName, addedTableName, knownTables) fieldValuePtr, err := col.ValueOf(bean) if err != nil { @@ -688,6 +720,13 @@ func (statement *Statement) TableName() string { return "" } +func (statement *Statement) outTableName() string { + if statement.OutTable != nil { + return statement.OutTable.Name + } + return "" +} + // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" func (statement *Statement) Id(id interface{}) *Statement { idValue := reflect.ValueOf(id) @@ -979,12 +1018,15 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } + var refName string switch tablename.(type) { case []string: t := tablename.([]string) if len(t) > 1 { + refName = t[1] fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) } else if len(t) == 1 { + refName = t[0] fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) } case []interface{}: @@ -1003,18 +1045,22 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition } } if l > 1 { + refName = fmt.Sprintf("%v", t[1]) fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), - statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) + statement.Engine.Quote(refName)) } else if l == 1 { + refName = table fmt.Fprintf(&buf, statement.Engine.Quote(table)) } default: - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) + refName = fmt.Sprintf("%v", tablename) + fmt.Fprintf(&buf, statement.Engine.Quote(refName)) } fmt.Fprintf(&buf, " ON %v", condition) statement.JoinStr = buf.String() statement.joinArgs = append(statement.joinArgs, args...) + statement.tableMap[statement.Engine.Quote(strings.ToLower(refName))] = true return statement } @@ -1037,7 +1083,23 @@ func (statement *Statement) Unscoped() *Statement { } func (statement *Statement) genColumnStr() string { + if len(statement.selectStr) > 0 { + return statement.selectStr + } + + if len(statement.ColumnStr) > 0 { + return statement.ColumnStr + } + + if len(statement.GroupByStr) > 0 { + return statement.Engine.Quote( + strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + } + table := statement.RefTable + if statement.OutTable != nil { + table = statement.OutTable + } colNames := make([]string, 0) for _, col := range table.Columns() { if statement.OmitStr != "" { @@ -1049,26 +1111,12 @@ func (statement *Statement) genColumnStr() string { continue } - if statement.JoinStr != "" { - var name string - if statement.TableAlias != "" { - name = statement.Engine.Quote(statement.TableAlias) - } else { - name = statement.Engine.Quote(statement.TableName()) - } - name += "." + statement.Engine.Quote(col.Name) - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - colNames = append(colNames, "id() AS "+name) - } else { - colNames = append(colNames, name) - } + name := statement.colName(col, statement.TableName()) + + if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { + colNames = append(colNames, "id() AS "+name) } else { - name := statement.Engine.Quote(col.Name) - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - colNames = append(colNames, "id() AS "+name) - } else { - colNames = append(colNames, name) - } + colNames = append(colNames, name) } } return strings.Join(colNames, ", ") @@ -1135,38 +1183,15 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) table = statement.RefTable } - var addedTableName = (len(statement.JoinStr) > 0) - if !statement.noAutoCondition { - colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) + colNames, args := statement.buildConditions( + table, bean, true, true, false, true, statement.needTableName()) statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") statement.BeanArgs = args } - var columnStr string = statement.ColumnStr - if len(statement.selectStr) > 0 { - columnStr = statement.selectStr - } else { - // TODO: always generate column names, not use * even if join - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = statement.genColumnStr() - } - } - } else { - if len(columnStr) == 0 { - if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = "*" - } - } - } - } + columnStr := statement.genColumnStr() statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)" return statement.genSelectSQL(columnStr), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) @@ -1195,7 +1220,7 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) { return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) + statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, statement.outTableName(), addedTableName, statement.tableMap) } func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { @@ -1347,3 +1372,24 @@ func (statement *Statement) processIdParam() { } } } + +func (statement *Statement) tableMapAdd(table string) { + tableName := statement.Engine.Quote(strings.ToLower(table)) + statement.tableMap[tableName] = true +} + +func (statement *Statement) tableMapDelete(table string) { + tableName := statement.Engine.Quote(strings.ToLower(table)) + delete(statement.tableMap, tableName) +} + +func isKnownTable(engine *Engine, tableMap map[string]bool, mainTable, table string) bool { + if len(table) > 0 { + cm := engine.Quote(strings.ToLower(mainTable)) + ct := engine.Quote(strings.ToLower(table)) + if ct == cm || tableMap[ct] { + return true + } + } + return false +}