From e4f05916cb0242b8dd5946ae17e572ff16c0bcf0 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 23 May 2014 14:18:45 +0800 Subject: [PATCH] add support for Join --- engine.go | 8 +----- session.go | 81 ++++++++++++++++++++++++++-------------------------- statement.go | 27 +++++++++++------- 3 files changed, 59 insertions(+), 57 deletions(-) diff --git a/engine.go b/engine.go index 107a9d69..e265a5f9 100644 --- a/engine.go +++ b/engine.go @@ -590,15 +590,10 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { continue } if strings.ToUpper(tags[0]) == "EXTENDS" { - - //fieldValue = reflect.Indirect(fieldValue) - //fmt.Println("----", fieldValue.Kind()) if fieldValue.Kind() == reflect.Struct { - //parentTable := mappingTable(fieldType, tableMapper, colMapper, dialect, tagId) parentTable := engine.mapType(fieldValue) for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", t.Field(i).Name, col.FieldName) - //fmt.Println("---", col.FieldName) table.AddColumn(col) } @@ -610,7 +605,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { if !fieldValue.IsValid() || fieldValue.IsNil() { fieldValue = reflect.New(f).Elem() } - //fmt.Println("00000", fieldValue) } parentTable := engine.mapType(fieldValue) @@ -732,7 +726,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { if col.Length2 == 0 { col.Length2 = col.SQLType.DefaultLength2 } - //fmt.Println("======", col) + if col.Name == "" { col.Name = engine.ColumnMapper.Obj2Table(t.Field(i).Name) } diff --git a/session.go b/session.go index e35b37c2..591850af 100644 --- a/session.go +++ b/session.go @@ -599,7 +599,6 @@ func (statement *Statement) convertIdSql(sqlStr string) string { if len(sqls) != 2 { return "" } - //fmt.Println("-----", col) newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(col.Name), sqls[1]) return newsql @@ -728,7 +727,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = make([]core.PK, 0) if len(resultsSlice) > 0 { for _, data := range resultsSlice { - //fmt.Println(data) var id int64 if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") @@ -939,7 +937,9 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sqlStr string var args []interface{} - session.Statement.RefTable = session.Engine.autoMap(bean) + if session.Statement.RefTable == nil { + session.Statement.RefTable = session.Engine.autoMap(bean) + } if session.Statement.RawSQL == "" { sqlStr, args = session.Statement.genGetSql(bean) @@ -948,10 +948,12 @@ func (session *Session) Get(bean interface{}) (bool, error) { args = session.Statement.RawParams } - if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache { - has, err := session.cacheGet(bean, sqlStr, args...) - if err != ErrCacheFailed { - return has, err + if session.Statement.JoinStr == "" { + if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache { + has, err := session.cacheGet(bean, sqlStr, args...) + if err != ErrCacheFailed { + return has, err + } } } @@ -1073,8 +1075,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr - if columnStr == "" { - columnStr = session.Statement.genColumnStr() + if session.Statement.JoinStr == "" { + if columnStr == "" { + columnStr = session.Statement.genColumnStr() + } + } else { + if columnStr == "" { + columnStr = "*" + } } session.Statement.attachInSql() @@ -1086,15 +1094,17 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) args = session.Statement.RawParams } - if cacher := session.Engine.getCacher2(table); cacher != nil && - session.Statement.UseCache && - !session.Statement.IsDistinct { - err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) - if err != ErrCacheFailed { - return err + if session.Statement.JoinStr == "" { + if cacher := session.Engine.getCacher2(table); cacher != nil && + session.Statement.UseCache && + !session.Statement.IsDistinct { + err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) + if err != ErrCacheFailed { + return err + } + err = nil // !nashtsai! reset err to nil for ErrCacheFailed + session.Engine.LogWarn("Cache Find Failed") } - err = nil // !nashtsai! reset err to nil for ErrCacheFailed - session.Engine.LogWarn("Cache Find Failed") } if sliceValue.Kind() != reflect.Map { @@ -1102,14 +1112,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var stmt *core.Stmt session.queryPreprocess(&sqlStr, args...) - // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) - // if err != nil { - // return err - // } - // if stmt != nil { - // defer stmt.Close() - // } - // defer rawRows.Close() if session.IsAutoCommit { stmt, err = session.doPrepare(sqlStr) @@ -1311,7 +1313,6 @@ func (session *Session) addColumn(colName string) error { if session.IsAutoClose { defer session.Close() } - //fmt.Println(session.Statement.RefTable) col := session.Statement.RefTable.GetColumn(colName) sql, args := session.Statement.genAddColumnStr(col) @@ -1344,7 +1345,6 @@ func (session *Session) addUnique(tableName, uqeName string) error { if session.IsAutoClose { defer session.Close() } - //fmt.Println(uqeName, session.Statement.RefTable.Uniques) index := session.Statement.RefTable.Indexes[uqeName] sqlStr := session.Engine.dialect.CreateIndexSql(tableName, index) _, err = session.exec(sqlStr) @@ -1402,9 +1402,9 @@ func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, er return result, nil } -func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table) *reflect.Value { +func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table, idx int) *reflect.Value { var col *core.Column - if col = table.GetColumn(key); col == nil { + if col = table.GetColumnIdx(key, idx); col == nil { session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.Columns())) return nil } @@ -1448,13 +1448,22 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i } } + var tempMap = make(map[string]int) for ii, key := range fields { - if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { + var idx int + var ok bool + if idx, ok = tempMap[strings.ToLower(key)]; !ok { + idx = 0 + } else { + idx = idx + 1 + } + tempMap[strings.ToLower(key)] = idx + + if fieldValue := session.getField(&dataStruct, key, table, idx); fieldValue != nil { rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) //if row is null then ignore if rawValue.Interface() == nil { - //fmt.Println("ignore ...", key, rawValue) continue } @@ -1485,9 +1494,6 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i vv := reflect.ValueOf(rawValue.Interface()) fieldType := fieldValue.Type() - - //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) - hasAssigned := false switch fieldType.Kind() { @@ -1767,7 +1773,6 @@ func query(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []ma return nil, err } defer rows.Close() - //fmt.Println(rows) return rows2maps(rows) } @@ -2034,11 +2039,9 @@ func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.T } sdata = strings.TrimSpace(sdata) - //fmt.Println(sdata) if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 { sdata = sdata[len(sdata)-8:] } - //fmt.Println(sdata) st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.ParseInLocation("2006-01-02 15:04:05", st, session.Engine.TZLocation) @@ -2069,7 +2072,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, key := col.Name fieldType := fieldValue.Type() - //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) switch fieldType.Kind() { case reflect.Complex64, reflect.Complex128: x := reflect.New(fieldType) @@ -2578,7 +2580,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } colPlaces := strings.Repeat("?, ", len(colNames)) - //fmt.Println(colNames, args) colPlaces = colPlaces[0 : len(colPlaces)-2] sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", @@ -2988,7 +2989,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.UseAutoTime && table.Updated != "" { colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") - args = append(args, session.Engine.NowTime(table.Columns()[strings.ToLower(table.Updated)].SQLType.Name)) + args = append(args, session.Engine.NowTime(table.UpdatedColumn().SQLType.Name)) } //for update action to like "column = column + ?" diff --git a/statement.go b/statement.go index 707f9c1d..c19247a2 100644 --- a/statement.go +++ b/statement.go @@ -277,8 +277,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if !includeAutoIncr && col.IsAutoIncrement { continue } - // - //fmt.Println(engine.dialect.DBType(), Text) + if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } @@ -382,7 +381,6 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, continue } val = engine.FormatTime(col.SQLType.Name, t) - //fmt.Println("-------", t, val, col.Name) } else { engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { @@ -470,8 +468,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if !includeAutoIncr && col.IsAutoIncrement { continue } - // - //fmt.Println(engine.dialect.DBType(), Text) + if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } @@ -555,7 +552,6 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, continue } val = engine.FormatTime(col.SQLType.Name, t) - //fmt.Println("-------", t, val, col.Name) } else { engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { @@ -948,8 +944,13 @@ func (s *Statement) genDropSQL() string { } func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { - table := statement.Engine.autoMap(bean) - statement.RefTable = table + var table *core.Table + if statement.RefTable == nil { + table = statement.Engine.autoMap(bean) + statement.RefTable = table + } else { + table = statement.RefTable + } colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, true, statement.allUseBool, statement.useAllCols, @@ -959,8 +960,14 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) statement.BeanArgs = args var columnStr string = statement.ColumnStr - if columnStr == "" { - columnStr = statement.genColumnStr() + if statement.JoinStr == "" { + if columnStr == "" { + columnStr = statement.genColumnStr() + } + } else { + if columnStr == "" { + columnStr = "*" + } } statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)"