Always generate column names, don't use * even if join

This commit is contained in:
Victor Gaydov 2016-04-28 16:26:20 +03:00
parent a01eeeddbc
commit 3442ce81b6
3 changed files with 151 additions and 113 deletions

View File

@ -929,8 +929,16 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
if ormTagStr != "" { if ormTagStr != "" {
col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, col = &core.Column{
IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)} FieldName: t.Field(i).Name,
TableName: table.Name,
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,
MapType: core.TWOSIDES,
Indexes: make(map[string]bool),
}
tags := splitTag(ormTagStr) tags := splitTag(ormTagStr)
if len(tags) > 0 { if len(tags) > 0 {
@ -952,6 +960,11 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
case reflect.Struct: case reflect.Struct:
parentTable := engine.mapType(fieldValue) parentTable := engine.mapType(fieldValue)
for _, col := range parentTable.Columns() { for _, col := range parentTable.Columns() {
if t.Field(i).Anonymous {
col.TableName = parentTable.Name
} else {
col.TableName = engine.TableMapper.Obj2Table(t.Field(i).Name)
}
col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
table.AddColumn(col) table.AddColumn(col)
} }
@ -1133,6 +1146,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name),
t.Field(i).Name, sqlType, sqlType.DefaultLength, t.Field(i).Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true) sqlType.DefaultLength2, true)
col.TableName = table.Name
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
col.Nullable = false col.Nullable = false

View File

@ -1044,8 +1044,10 @@ func (session *Session) Get(bean interface{}) (bool, error) {
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
session.Statement.OutTable = session.Engine.TableInfo(bean)
if session.Statement.RefTable == nil { if session.Statement.RefTable == nil {
session.Statement.RefTable = session.Engine.TableInfo(bean) session.Statement.RefTable = session.Statement.OutTable
} }
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
@ -1139,42 +1141,39 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var table *core.Table var table *core.Table
if session.Statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
pv := reflect.New(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
table = session.Engine.autoMapType(pv.Elem())
} else {
return errors.New("slice type")
}
} else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType)
table = session.Engine.autoMapType(pv.Elem()) table = session.Engine.autoMapType(pv.Elem())
} else { } else {
return errors.New("slice type") return errors.New("slice type")
} }
session.Statement.RefTable = table } else if sliceElementType.Kind() == reflect.Struct {
pv := reflect.New(sliceElementType)
table = session.Engine.autoMapType(pv.Elem())
} else { } else {
table = session.Statement.RefTable return errors.New("slice type")
}
session.Statement.OutTable = table
if session.Statement.RefTable == nil {
session.Statement.RefTable = table
} }
var addedTableName = (len(session.Statement.JoinStr) > 0)
if !session.Statement.noAutoCondition && len(condiBean) > 0 { if !session.Statement.noAutoCondition && len(condiBean) > 0 {
colNames, args := session.Statement.buildConditions(table, condiBean[0], true, true, false, true, addedTableName) colNames, args := session.Statement.buildConditions(
table, condiBean[0], true, true, false, true, session.Statement.needTableName())
session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.ConditionStr = strings.Join(colNames, " AND ")
session.Statement.BeanArgs = args session.Statement.BeanArgs = args
} else { } else {
// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given. // !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
// See https://github.com/go-xorm/xorm/issues/179 // See https://github.com/go-xorm/xorm/issues/179
if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped {
var colName = session.Engine.Quote(col.Name) // tag "deleted" is enabled
if addedTableName { var colName = session.Statement.colName(col, session.Statement.TableName())
var nm = session.Statement.TableName()
if len(session.Statement.TableAlias) > 0 {
nm = session.Statement.TableAlias
}
colName = session.Engine.Quote(nm) + "." + colName
}
session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')", session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL OR %v = '0001-01-01 00:00:00')",
colName, colName) colName, colName)
} }
@ -1183,28 +1182,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
var columnStr = session.Statement.ColumnStr columnStr := session.Statement.genColumnStr()
if len(session.Statement.selectStr) > 0 {
columnStr = session.Statement.selectStr
} else {
if session.Statement.JoinStr == "" {
if columnStr == "" {
if session.Statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
} else {
columnStr = session.Statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if session.Statement.GroupByStr != "" {
columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1))
} else {
columnStr = "*"
}
}
}
}
session.Statement.Params = append(session.Statement.joinArgs, append(session.Statement.Params, session.Statement.BeanArgs...)...) session.Statement.Params = append(session.Statement.joinArgs, append(session.Statement.Params, session.Statement.BeanArgs...)...)

View File

@ -40,6 +40,7 @@ type exprParam struct {
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *core.Table RefTable *core.Table
OutTable *core.Table
Engine *Engine Engine *Engine
Start int Start int
LimitN int LimitN int
@ -54,6 +55,7 @@ type Statement struct {
ColumnStr string ColumnStr string
selectStr string selectStr string
columnMap map[string]bool columnMap map[string]bool
tableMap map[string]bool
useAllCols bool useAllCols bool
OmitStr string OmitStr string
ConditionStr string ConditionStr string
@ -85,6 +87,7 @@ type Statement struct {
// Init reset all the statment's fields // Init reset all the statment's fields
func (statement *Statement) Init() { func (statement *Statement) Init() {
statement.RefTable = nil statement.RefTable = nil
statement.OutTable = nil
statement.Start = 0 statement.Start = 0
statement.LimitN = 0 statement.LimitN = 0
statement.WhereStr = "" statement.WhereStr = ""
@ -98,6 +101,7 @@ func (statement *Statement) Init() {
statement.ColumnStr = "" statement.ColumnStr = ""
statement.OmitStr = "" statement.OmitStr = ""
statement.columnMap = make(map[string]bool) statement.columnMap = make(map[string]bool)
statement.tableMap = make(map[string]bool)
statement.ConditionStr = "" statement.ConditionStr = ""
statement.AltTableName = "" statement.AltTableName = ""
statement.IdParam = nil statement.IdParam = nil
@ -141,7 +145,14 @@ func (statement *Statement) Sql(querystring string, args ...interface{}) *Statem
// Alias set the table alias // Alias set the table alias
func (statement *Statement) Alias(alias string) *Statement { func (statement *Statement) Alias(alias string) *Statement {
if statement.TableName() != "" {
statement.tableMapDelete(statement.TableName())
}
if statement.TableAlias != "" {
statement.tableMapDelete(statement.TableAlias)
}
statement.TableAlias = alias statement.TableAlias = alias
statement.tableMapAdd(alias)
return statement return statement
} }
@ -190,6 +201,9 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme
// Table tempororily set table name, the parameter could be a string or a pointer of struct // Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
if statement.TableName() != "" {
statement.tableMapDelete(statement.TableName())
}
v := rValue(tableNameOrBean) v := rValue(tableNameOrBean)
t := v.Type() t := v.Type()
if t.Kind() == reflect.String { if t.Kind() == reflect.String {
@ -197,6 +211,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.autoMapType(v) statement.RefTable = statement.Engine.autoMapType(v)
} }
statement.tableMapAdd(statement.TableName())
return statement return statement
} }
@ -440,21 +455,46 @@ func (statement *Statement) needTableName() bool {
} }
func (statement *Statement) colName(col *core.Column, tableName string) string { func (statement *Statement) colName(col *core.Column, tableName string) string {
if statement.needTableName() { return buildColName(statement.Engine, col, tableName, statement.TableAlias,
var nm = tableName statement.outTableName(), statement.needTableName(), statement.tableMap)
if len(statement.TableAlias) > 0 { }
nm = statement.TableAlias
func buildColName(engine *Engine, col *core.Column,
mainTableName, mainTableAlias, outTableName string, needTableName bool,
knownTables map[string]bool) string {
var colTable string
if needTableName {
var mainTable string
if len(mainTableAlias) > 0 {
mainTable = mainTableAlias
} else {
mainTable = mainTableName
}
if isKnownTable(engine, knownTables, mainTable, col.TableName) {
colTable = col.TableName
} else if isKnownTable(engine, knownTables, mainTable, outTableName) {
colTable = outTableName
} else {
colTable = ""
} }
return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
} }
return statement.Engine.Quote(col.Name)
if colTable != "" {
return engine.Quote(colTable) + "." + engine.Quote(col.Name)
} else {
return engine.Quote(col.Name)
}
} }
// Auto generating conditions according a struct // Auto generating conditions according a struct
func buildConditions(engine *Engine, table *core.Table, bean interface{}, func buildConditions(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) ([]string, []interface{}) { mustColumnMap map[string]bool, tableName, aliasName, outTableName string,
addedTableName bool, knownTables map[string]bool) ([]string, []interface{}) {
var colNames []string var colNames []string
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
for _, col := range table.Columns() { for _, col := range table.Columns() {
@ -475,16 +515,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
continue continue
} }
var colName string colName := buildColName(
if addedTableName { engine, col, tableName, aliasName, outTableName, addedTableName, knownTables)
var nm = tableName
if len(aliasName) > 0 {
nm = aliasName
}
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
} else {
colName = engine.Quote(col.Name)
}
fieldValuePtr, err := col.ValueOf(bean) fieldValuePtr, err := col.ValueOf(bean)
if err != nil { if err != nil {
@ -688,6 +720,13 @@ func (statement *Statement) TableName() string {
return "" return ""
} }
func (statement *Statement) outTableName() string {
if statement.OutTable != nil {
return statement.OutTable.Name
}
return ""
}
// Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) Id(id interface{}) *Statement { func (statement *Statement) Id(id interface{}) *Statement {
idValue := reflect.ValueOf(id) idValue := reflect.ValueOf(id)
@ -979,12 +1018,15 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
var refName string
switch tablename.(type) { switch tablename.(type) {
case []string: case []string:
t := tablename.([]string) t := tablename.([]string)
if len(t) > 1 { if len(t) > 1 {
refName = t[1]
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 { } else if len(t) == 1 {
refName = t[0]
fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
} }
case []interface{}: case []interface{}:
@ -1003,18 +1045,22 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
} }
} }
if l > 1 { if l > 1 {
refName = fmt.Sprintf("%v", t[1])
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) statement.Engine.Quote(refName))
} else if l == 1 { } else if l == 1 {
refName = table
fmt.Fprintf(&buf, statement.Engine.Quote(table)) fmt.Fprintf(&buf, statement.Engine.Quote(table))
} }
default: default:
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) refName = fmt.Sprintf("%v", tablename)
fmt.Fprintf(&buf, statement.Engine.Quote(refName))
} }
fmt.Fprintf(&buf, " ON %v", condition) fmt.Fprintf(&buf, " ON %v", condition)
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...) statement.joinArgs = append(statement.joinArgs, args...)
statement.tableMap[statement.Engine.Quote(strings.ToLower(refName))] = true
return statement return statement
} }
@ -1037,7 +1083,23 @@ func (statement *Statement) Unscoped() *Statement {
} }
func (statement *Statement) genColumnStr() string { func (statement *Statement) genColumnStr() string {
if len(statement.selectStr) > 0 {
return statement.selectStr
}
if len(statement.ColumnStr) > 0 {
return statement.ColumnStr
}
if len(statement.GroupByStr) > 0 {
return statement.Engine.Quote(
strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
}
table := statement.RefTable table := statement.RefTable
if statement.OutTable != nil {
table = statement.OutTable
}
colNames := make([]string, 0) colNames := make([]string, 0)
for _, col := range table.Columns() { for _, col := range table.Columns() {
if statement.OmitStr != "" { if statement.OmitStr != "" {
@ -1049,26 +1111,12 @@ func (statement *Statement) genColumnStr() string {
continue continue
} }
if statement.JoinStr != "" { name := statement.colName(col, statement.TableName())
var name string
if statement.TableAlias != "" { if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
name = statement.Engine.Quote(statement.TableAlias) colNames = append(colNames, "id() AS "+name)
} else {
name = statement.Engine.Quote(statement.TableName())
}
name += "." + statement.Engine.Quote(col.Name)
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
colNames = append(colNames, "id() AS "+name)
} else {
colNames = append(colNames, name)
}
} else { } else {
name := statement.Engine.Quote(col.Name) colNames = append(colNames, name)
if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" {
colNames = append(colNames, "id() AS "+name)
} else {
colNames = append(colNames, name)
}
} }
} }
return strings.Join(colNames, ", ") return strings.Join(colNames, ", ")
@ -1135,38 +1183,15 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
table = statement.RefTable table = statement.RefTable
} }
var addedTableName = (len(statement.JoinStr) > 0)
if !statement.noAutoCondition { if !statement.noAutoCondition {
colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) colNames, args := statement.buildConditions(
table, bean, true, true, false, true, statement.needTableName())
statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ")
statement.BeanArgs = args statement.BeanArgs = args
} }
var columnStr string = statement.ColumnStr columnStr := statement.genColumnStr()
if len(statement.selectStr) > 0 {
columnStr = statement.selectStr
} else {
// TODO: always generate column names, not use * even if join
if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
} else {
columnStr = "*"
}
}
}
}
statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)" statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)"
return statement.genSelectSQL(columnStr), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) return statement.genSelectSQL(columnStr), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...)
@ -1195,7 +1220,7 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in
func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) { func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) {
return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, statement.outTableName(), addedTableName, statement.tableMap)
} }
func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
@ -1347,3 +1372,24 @@ func (statement *Statement) processIdParam() {
} }
} }
} }
func (statement *Statement) tableMapAdd(table string) {
tableName := statement.Engine.Quote(strings.ToLower(table))
statement.tableMap[tableName] = true
}
func (statement *Statement) tableMapDelete(table string) {
tableName := statement.Engine.Quote(strings.ToLower(table))
delete(statement.tableMap, tableName)
}
func isKnownTable(engine *Engine, tableMap map[string]bool, mainTable, table string) bool {
if len(table) > 0 {
cm := engine.Quote(strings.ToLower(mainTable))
ct := engine.Quote(strings.ToLower(table))
if ct == cm || tableMap[ct] {
return true
}
}
return false
}