add support for Join

This commit is contained in:
Lunny Xiao 2014-05-23 14:18:45 +08:00
parent 14dd1f720b
commit e4f05916cb
3 changed files with 59 additions and 57 deletions

View File

@ -590,15 +590,10 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
continue continue
} }
if strings.ToUpper(tags[0]) == "EXTENDS" { if strings.ToUpper(tags[0]) == "EXTENDS" {
//fieldValue = reflect.Indirect(fieldValue)
//fmt.Println("----", fieldValue.Kind())
if fieldValue.Kind() == reflect.Struct { if fieldValue.Kind() == reflect.Struct {
//parentTable := mappingTable(fieldType, tableMapper, colMapper, dialect, tagId)
parentTable := engine.mapType(fieldValue) parentTable := engine.mapType(fieldValue)
for _, col := range parentTable.Columns() { for _, col := range parentTable.Columns() {
col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName)
//fmt.Println("---", col.FieldName)
table.AddColumn(col) table.AddColumn(col)
} }
@ -610,7 +605,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
if !fieldValue.IsValid() || fieldValue.IsNil() { if !fieldValue.IsValid() || fieldValue.IsNil() {
fieldValue = reflect.New(f).Elem() fieldValue = reflect.New(f).Elem()
} }
//fmt.Println("00000", fieldValue)
} }
parentTable := engine.mapType(fieldValue) parentTable := engine.mapType(fieldValue)
@ -732,7 +726,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
if col.Length2 == 0 { if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2 col.Length2 = col.SQLType.DefaultLength2
} }
//fmt.Println("======", col)
if col.Name == "" { if col.Name == "" {
col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name) col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name)
} }

View File

@ -599,7 +599,6 @@ func (statement *Statement) convertIdSql(sqlStr string) string {
if len(sqls) != 2 { if len(sqls) != 2 {
return "" return ""
} }
//fmt.Println("-----", col)
newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()),
statement.Engine.Quote(col.Name), sqls[1]) statement.Engine.Quote(col.Name), sqls[1])
return newsql return newsql
@ -728,7 +727,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
ids = make([]core.PK, 0) ids = make([]core.PK, 0)
if len(resultsSlice) > 0 { if len(resultsSlice) > 0 {
for _, data := range resultsSlice { for _, data := range resultsSlice {
//fmt.Println(data)
var id int64 var id int64
if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok {
return errors.New("no id") return errors.New("no id")
@ -939,7 +937,9 @@ func (session *Session) Get(bean interface{}) (bool, error) {
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
session.Statement.RefTable = session.Engine.autoMap(bean) if session.Statement.RefTable == nil {
session.Statement.RefTable = session.Engine.autoMap(bean)
}
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
sqlStr, args = session.Statement.genGetSql(bean) sqlStr, args = session.Statement.genGetSql(bean)
@ -948,10 +948,12 @@ func (session *Session) Get(bean interface{}) (bool, error) {
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache { if session.Statement.JoinStr == "" {
has, err := session.cacheGet(bean, sqlStr, args...) if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache {
if err != ErrCacheFailed { has, err := session.cacheGet(bean, sqlStr, args...)
return has, err if err != ErrCacheFailed {
return has, err
}
} }
} }
@ -1073,8 +1075,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var args []interface{} var args []interface{}
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
var columnStr string = session.Statement.ColumnStr var columnStr string = session.Statement.ColumnStr
if columnStr == "" { if session.Statement.JoinStr == "" {
columnStr = session.Statement.genColumnStr() if columnStr == "" {
columnStr = session.Statement.genColumnStr()
}
} else {
if columnStr == "" {
columnStr = "*"
}
} }
session.Statement.attachInSql() session.Statement.attachInSql()
@ -1086,15 +1094,17 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if cacher := session.Engine.getCacher2(table); cacher != nil && if session.Statement.JoinStr == "" {
session.Statement.UseCache && if cacher := session.Engine.getCacher2(table); cacher != nil &&
!session.Statement.IsDistinct { session.Statement.UseCache &&
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) !session.Statement.IsDistinct {
if err != ErrCacheFailed { err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
return err if err != ErrCacheFailed {
return err
}
err = nil // !nashtsai! reset err to nil for ErrCacheFailed
session.Engine.LogWarn("Cache Find Failed")
} }
err = nil // !nashtsai! reset err to nil for ErrCacheFailed
session.Engine.LogWarn("Cache Find Failed")
} }
if sliceValue.Kind() != reflect.Map { if sliceValue.Kind() != reflect.Map {
@ -1102,14 +1112,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var stmt *core.Stmt var stmt *core.Stmt
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
// err = session.queryRows(&stmt, &rawRows, sqlStr, args...)
// if err != nil {
// return err
// }
// if stmt != nil {
// defer stmt.Close()
// }
// defer rawRows.Close()
if session.IsAutoCommit { if session.IsAutoCommit {
stmt, err = session.doPrepare(sqlStr) stmt, err = session.doPrepare(sqlStr)
@ -1311,7 +1313,6 @@ func (session *Session) addColumn(colName string) error {
if session.IsAutoClose { if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
//fmt.Println(session.Statement.RefTable)
col := session.Statement.RefTable.GetColumn(colName) col := session.Statement.RefTable.GetColumn(colName)
sql, args := session.Statement.genAddColumnStr(col) sql, args := session.Statement.genAddColumnStr(col)
@ -1344,7 +1345,6 @@ func (session *Session) addUnique(tableName, uqeName string) error {
if session.IsAutoClose { if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
//fmt.Println(uqeName, session.Statement.RefTable.Uniques)
index := session.Statement.RefTable.Indexes[uqeName] index := session.Statement.RefTable.Indexes[uqeName]
sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index) sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index)
_, err = session.exec(sqlStr) _, err = session.exec(sqlStr)
@ -1402,9 +1402,9 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er
return result, nil return result, nil
} }
func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table) *reflect.Value { func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value {
var col *core.Column var col *core.Column
if col = table.GetColumn(key); col == nil { if col = table.GetColumnIdx(key, idx); col == nil {
session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.Columns())) session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.Columns()))
return nil return nil
} }
@ -1448,13 +1448,22 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
} }
} }
var tempMap = make(map[string]int)
for ii, key := range fields { for ii, key := range fields {
if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { var idx int
var ok bool
if idx, ok = tempMap[strings.ToLower(key)]; !ok {
idx = 0
} else {
idx = idx + 1
}
tempMap[strings.ToLower(key)] = idx
if fieldValue := session.getField(&dataStruct, key, table, idx); fieldValue != nil {
rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii]))
//if row is null then ignore //if row is null then ignore
if rawValue.Interface() == nil { if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue continue
} }
@ -1485,9 +1494,6 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
vv := reflect.ValueOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
//fmt.Println("column name:", key, ", fieldType:", fieldType.String())
hasAssigned := false hasAssigned := false
switch fieldType.Kind() { switch fieldType.Kind() {
@ -1767,7 +1773,6 @@ func query(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []ma
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
//fmt.Println(rows)
return rows2maps(rows) return rows2maps(rows)
} }
@ -2034,11 +2039,9 @@ func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.T
} }
sdata = strings.TrimSpace(sdata) sdata = strings.TrimSpace(sdata)
//fmt.Println(sdata)
if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:] sdata = sdata[len(sdata)-8:]
} }
//fmt.Println(sdata)
st := fmt.Sprintf("2006-01-02 %v", sdata) st := fmt.Sprintf("2006-01-02 %v", sdata)
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, session.Engine.TZLocation) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, session.Engine.TZLocation)
@ -2069,7 +2072,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
key := col.Name key := col.Name
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
//fmt.Println("column name:", key, ", fieldType:", fieldType.String())
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
x := reflect.New(fieldType) x := reflect.New(fieldType)
@ -2578,7 +2580,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces := strings.Repeat("?, ", len(colNames))
//fmt.Println(colNames, args)
colPlaces = colPlaces[0 : len(colPlaces)-2] colPlaces = colPlaces[0 : len(colPlaces)-2]
sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
@ -2988,7 +2989,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if session.Statement.UseAutoTime && table.Updated != "" { if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, session.Engine.NowTime(table.Columns()[strings.ToLower(table.Updated)].SQLType.Name)) args = append(args, session.Engine.NowTime(table.UpdatedColumn().SQLType.Name))
} }
//for update action to like "column = column + ?" //for update action to like "column = column + ?"

View File

@ -277,8 +277,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if !includeAutoIncr && col.IsAutoIncrement { if !includeAutoIncr && col.IsAutoIncrement {
continue continue
} }
//
//fmt.Println(engine.dialect.DBType(), Text)
if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
continue continue
} }
@ -382,7 +381,6 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
continue continue
} }
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
//fmt.Println("-------", t, val, col.Name)
} else { } else {
engine.autoMapType(fieldValue) engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {
@ -470,8 +468,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
if !includeAutoIncr && col.IsAutoIncrement { if !includeAutoIncr && col.IsAutoIncrement {
continue continue
} }
//
//fmt.Println(engine.dialect.DBType(), Text)
if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
continue continue
} }
@ -555,7 +552,6 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
continue continue
} }
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
//fmt.Println("-------", t, val, col.Name)
} else { } else {
engine.autoMapType(fieldValue) engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {
@ -948,8 +944,13 @@ func (s *Statement) genDropSQL() string {
} }
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.autoMap(bean) var table *core.Table
statement.RefTable = table if statement.RefTable == nil {
table = statement.Engine.autoMap(bean)
statement.RefTable = table
} else {
table = statement.RefTable
}
colNames, args := buildConditions(statement.Engine, table, bean, true, true, colNames, args := buildConditions(statement.Engine, table, bean, true, true,
false, true, statement.allUseBool, statement.useAllCols, false, true, statement.allUseBool, statement.useAllCols,
@ -959,8 +960,14 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
statement.BeanArgs = args statement.BeanArgs = args
var columnStr string = statement.ColumnStr var columnStr string = statement.ColumnStr
if columnStr == "" { if statement.JoinStr == "" {
columnStr = statement.genColumnStr() if columnStr == "" {
columnStr = statement.genColumnStr()
}
} else {
if columnStr == "" {
columnStr = "*"
}
} }
statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)" statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)"