diff --git a/base_test.go b/base_test.go index 84770168..66631160 100644 --- a/base_test.go +++ b/base_test.go @@ -359,7 +359,7 @@ func updateSameMapper(engine *Engine, t *testing.T) { } } -func testdelete(engine *Engine, t *testing.T) { +func testDelete(engine *Engine, t *testing.T) { user := Userinfo{Uid: 1} cnt, err := engine.Delete(&user) if err != nil { @@ -3788,8 +3788,8 @@ func testAll(engine *Engine, t *testing.T) { insertTwoTable(engine, t) fmt.Println("-------------- update --------------") update(engine, t) - fmt.Println("-------------- testdelete --------------") - testdelete(engine, t) + fmt.Println("-------------- testDelete --------------") + testDelete(engine, t) fmt.Println("-------------- get --------------") get(engine, t) fmt.Println("-------------- cascadeGet --------------") diff --git a/helpers.go b/helpers.go index 307353c2..96f118f2 100644 --- a/helpers.go +++ b/helpers.go @@ -1,8 +1,12 @@ package xorm import ( + "database/sql" + "fmt" "reflect" + "strconv" "strings" + "time" ) func indexNoCase(s, sep string) int { @@ -61,3 +65,72 @@ func sliceEq(left, right []string) bool { return true } + +func value2Bytes(rawValue *reflect.Value) (data []byte, err error) { + + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + + var str string + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + data = []byte(str) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + data = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + data = []byte(str) + case reflect.String: + str = vv.String() + data = []byte(str) + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data = rawValue.Interface().([]byte) + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + //时间类型 + case reflect.Struct: + if aa == reflect.TypeOf(c_TIME_DEFAULT) { + str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) + data = []byte(str) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + data = []byte(str) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + data = []byte(str) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + +func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2map(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} diff --git a/rows.go b/rows.go index 7aeb87f0..0ac6c956 100644 --- a/rows.go +++ b/rows.go @@ -9,12 +9,13 @@ import ( type Rows struct { NoTypeCheck bool - session *Session - stmt *sql.Stmt - rows *sql.Rows - fields []string - beanType reflect.Type - lastError error + session *Session + stmt *sql.Stmt + rows *sql.Rows + fields []string + fieldsCount int + beanType reflect.Type + lastError error } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -66,6 +67,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { defer rows.Close() return nil, err } + rows.fieldsCount = len(rows.fields) return rows, nil } @@ -97,11 +99,13 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - result, err := row2map(rows.rows, rows.fields) // !nashtsai! TODO remove row2map then scanMapIntoStruct conversation for better performance - if err == nil { - err = rows.session.scanMapIntoStruct(bean, result) - } - return err + return rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean) + + // result, err := row2map(rows.rows, rows.fields) // !nashtsai! TODO remove row2map then scanMapIntoStruct conversation for better performance + // if err == nil { + // err = rows.session.scanMapIntoStruct(bean, result) + // } + // return err } // // Columns returns the column names. Columns returns an error if the rows are closed, or if the rows are from QueryRow and there was a deferred error. diff --git a/session.go b/session.go index 883476bb..7623009b 100644 --- a/session.go +++ b/session.go @@ -186,8 +186,8 @@ func (session *Session) Desc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " DESC" return session } @@ -197,8 +197,8 @@ func (session *Session) Asc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " ASC" return session } @@ -394,8 +394,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b } //Execute sql -func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(sql) +func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(sqlStr) if err != nil { return nil, err } @@ -408,22 +408,22 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, return res, nil } -func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - session.Engine.LogSQL(sql) + session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(args) if session.IsAutoCommit { - return session.innerExec(sql, args...) + return session.innerExec(sqlStr, args...) } - return session.Tx.Exec(sql, args...) + return session.Tx.Exec(sqlStr, args...) } // Exec raw sql -func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { err := session.newDb() if err != nil { return nil, err @@ -433,7 +433,7 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error defer session.Close() } - return session.exec(sql, args...) + return session.exec(sqlStr, args...) } // this function create a table according a bean @@ -466,8 +466,8 @@ func (session *Session) CreateIndexes(bean interface{}) error { } sqls := session.Statement.genIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -489,8 +489,8 @@ func (session *Session) CreateUniques(bean interface{}) error { } sqls := session.Statement.genUniqueSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -499,9 +499,9 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createOneTable() error { - sql := session.Statement.genCreateTableSQL() - session.Engine.LogDebug("create table sql: [", sql, "]") - _, err := session.exec(sql) + sqlStr := session.Statement.genCreateTableSQL() + session.Engine.LogDebug("create table sql: [", sqlStr, "]") + _, err := session.exec(sqlStr) return err } @@ -538,8 +538,8 @@ func (session *Session) DropIndexes(bean interface{}) error { } sqls := session.Statement.genDelIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -569,16 +569,16 @@ func (session *Session) DropTable(bean interface{}) error { return errors.New("Unsupported type") } - sql := session.Statement.genDropSQL() - _, err = session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err = session.exec(sqlStr) return err } -func (statement *Statement) convertIdSql(sql string) string { +func (statement *Statement) convertIdSql(sqlStr string) string { if statement.RefTable != nil { col := statement.RefTable.PKColumns()[0] if col != nil { - sqls := splitNNoCase(sql, "from", 2) + sqls := splitNNoCase(sqlStr, "from", 2) if len(sqls) != 2 { return "" } @@ -590,15 +590,15 @@ func (statement *Statement) convertIdSql(sql string) string { return "" } -func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { +func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { // if has no reftable or number of pks is not equal to 1, then don't use cache currently if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return false, ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return false, ErrCacheFailed } @@ -670,19 +670,19 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface return false, nil } -func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { +func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 || - indexNoCase(sql, "having") != -1 || - indexNoCase(sql, "group by") != -1 { + indexNoCase(sqlStr, "having") != -1 || + indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -881,42 +881,68 @@ func (session *Session) Get(bean interface{}) (bool, error) { } session.Statement.Limit(1) - var sql string + var sqlStr string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) + sqlStr, args = session.Statement.genGetSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { - has, err := session.cacheGet(bean, sql, args...) + has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err } } - resultsSlice, err := session.query(sql, args...) + var rawRows *sql.Rows + session.queryPreprocess(&sqlStr, args...) + if session.IsAutoCommit { + stmt, err := session.Db.Prepare(sqlStr) + if err != nil { + return false, err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) + } else { + rawRows, err = session.Tx.Query(sqlStr, args...) + } if err != nil { return false, err } - if len(resultsSlice) < 1 { + defer rawRows.Close() + + if rawRows.Next() { + if fields, err := rawRows.Columns(); err == nil { + err = session.row2Bean(rawRows, fields, len(fields), bean) + } + return true, err + } else { return false, nil } - err = session.scanMapIntoStruct(bean, resultsSlice[0]) - if err != nil { - return true, err - } - if len(resultsSlice) == 1 { - return true, nil - } else { - return true, errors.New("More than one record") - } + // resultsSlice, err := session.query(sqlStr, args...) + // if err != nil { + // return false, err + // } + // if len(resultsSlice) < 1 { + // return false, nil + // } + + // err = session.scanMapIntoStruct(bean, resultsSlice[0]) + // if err != nil { + // return true, err + // } + // if len(resultsSlice) == 1 { + // return true, nil + // } else { + // return true, errors.New("More than one record") + // } } // Count counts the records. bean's non-empty fields @@ -932,16 +958,16 @@ func (session *Session) Count(bean interface{}) (int64, error) { defer session.Close() } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - sql, args = session.Statement.genCountSql(bean) + sqlStr, args = session.Statement.genCountSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } - resultsSlice, err := session.query(sql, args...) + resultsSlice, err := session.query(sqlStr, args...) if err != nil { return 0, err } @@ -1002,7 +1028,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.BeanArgs = args } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr @@ -1012,46 +1038,96 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.attachInSql() - sql = session.Statement.genSelectSql(columnStr) + sqlStr = session.Statement.genSelectSql(columnStr) args = append(session.Statement.Params, session.Statement.BeanArgs...) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if table.Cacher != nil && session.Statement.UseCache && !session.Statement.IsDistinct { - err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) + err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } + err = nil // !nashtsai! reset err to nil for ErrCacheFailed session.Engine.LogWarn("Cache Find Failed") } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return err - } + if sliceValue.Kind() != reflect.Map { + var rawRows *sql.Rows + var stmt *sql.Stmt - for i, results := range resultsSlice { - var newValue reflect.Value - if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) + session.queryPreprocess(&sqlStr, args...) + // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) + // if err != nil { + // return err + // } + // if stmt != nil { + // defer stmt.Close() + // } + // defer rawRows.Close() + + if session.IsAutoCommit { + stmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) } else { - newValue = reflect.New(sliceElementType) + rawRows, err = session.Tx.Query(sqlStr, args...) } - err := session.scanMapIntoStruct(newValue.Interface(), results) if err != nil { return err } - if sliceValue.Kind() == reflect.Slice { + defer rawRows.Close() + + fields, err := rawRows.Columns() + if err != nil { + return err + } + + fieldsCount := len(fields) + + for rawRows.Next() { + var newValue reflect.Value if sliceElementType.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + newValue = reflect.New(sliceElementType.Elem()) } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + newValue = reflect.New(sliceElementType) + } + err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface()) + if err != nil { + return err + } + if sliceValue.Kind() == reflect.Slice { + if sliceElementType.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + } + } + } else { + resultsSlice, err := session.query(sqlStr, args...) + if err != nil { + return err + } + + for i, results := range resultsSlice { + var newValue reflect.Value + if sliceElementType.Kind() == reflect.Ptr { + newValue = reflect.New(sliceElementType.Elem()) + } else { + newValue = reflect.New(sliceElementType) + } + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err } - } else if sliceValue.Kind() == reflect.Map { var key int64 // if there is only one pk, we can put the id as map key. // TODO: should know if the column is ints @@ -1074,6 +1150,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } +func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { + var err error + if session.IsAutoCommit { + *rawStmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return err + } + *rawRows, err = (*rawStmt).Query(args...) + } else { + *rawRows, err = session.Tx.Query(sqlStr, args...) + } + return err +} + // Test if database is ok func (session *Session) Ping() error { err := session.newDb() @@ -1097,8 +1187,8 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1111,8 +1201,8 @@ func (session *Session) isTableExist(tableName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1131,8 +1221,8 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo } else { idx = indexName(tableName, idxName) } - sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1165,6 +1255,7 @@ func (session *Session) addColumn(colName string) error { defer session.Close() } //fmt.Println(session.Statement.RefTable) + col := session.Statement.RefTable.Columns[strings.ToLower(colName)] sql, args := session.Statement.genAddColumnStr(col) _, err = session.exec(sql, args...) @@ -1182,8 +1273,8 @@ func (session *Session) addIndex(tableName, idxName string) error { } //fmt.Println(idxName) cols := session.Statement.RefTable.Indexes[idxName].Cols - sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1198,8 +1289,8 @@ func (session *Session) addUnique(tableName, uqeName string) error { } //fmt.Println(uqeName, session.Statement.RefTable.Uniques) cols := session.Statement.RefTable.Indexes[uqeName].Cols - sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1217,8 +1308,8 @@ func (session *Session) dropAll() error { for _, table := range session.Engine.Tables { session.Statement.Init() session.Statement.RefTable = table - sql := session.Statement.genDropSQL() - _, err := session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err := session.exec(sqlStr) if err != nil { return err } @@ -1237,8 +1328,6 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err return nil, err } - // !nashtsai! TODO optimization for query performance, where current process has gone from - // sql driver converted type back to []bytes then to ORM's fields for ii, key := range fields { rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) //if row is null then ignore @@ -1246,90 +1335,336 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err //fmt.Println("ignore ...", key, rawValue) continue } - aa := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - var str string - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - result[key] = []byte(str) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - result[key] = []byte(str) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - result[key] = []byte(str) - case reflect.String: - str = vv.String() - result[key] = []byte(str) - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - result[key] = rawValue.Interface().([]byte) - str = string(result[key]) - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - //时间类型 - case reflect.Struct: - if aa == reflect.TypeOf(c_TIME_DEFAULT) { - str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) - result[key] = []byte(str) - } else { - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - result[key] = []byte(str) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - result[key] = []byte(str) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) + + if data, err := value2Bytes(&rawValue); err == nil { + result[key] = data + } else { + return nil, err // !nashtsai! REVIEW, should return err or just error log? } } return result, nil } -func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } +func (session *Session) getField(dataStruct *reflect.Value, key string, table *Table) *reflect.Value { - return resultsSlice, nil + key = strings.ToLower(key) + if _, ok := table.Columns[key]; !ok { + session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) + return nil + } + col := table.Columns[key] + fieldName := col.FieldName + fieldPath := strings.Split(fieldName, ".") + var fieldValue reflect.Value + if len(fieldPath) > 2 { + session.Engine.LogError("Unsupported mutliderive", fieldName) + return nil + } else if len(fieldPath) == 2 { + parentField := dataStruct.FieldByName(fieldPath[0]) + if parentField.IsValid() { + fieldValue = parentField.FieldByName(fieldPath[1]) + } + } else { + fieldValue = dataStruct.FieldByName(fieldName) + } + if !fieldValue.IsValid() || !fieldValue.CanSet() { + session.Engine.LogWarn("table %v's column %v is not valid or cannot set", + table.Name, key) + return nil + } + return &fieldValue } -func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) +func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { + + dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + if dataStruct.Kind() != reflect.Struct { + return errors.New("Expected a pointer to a struct") } - session.Engine.LogSQL(sql) + table := session.Engine.autoMapType(rType(bean)) + + var scanResultContainers []interface{} + for i := 0; i < fieldsCount; i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := rows.Scan(scanResultContainers...); err != nil { + return err + } + + for ii, key := range fields { + if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { + + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + //fmt.Println("ignore ...", key, rawValue) + continue + } + + if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + structConvert.FromDB(data) + } else { + session.Engine.LogError(err) + } + continue + } + + rawValueType := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + + fieldType := fieldValue.Type() + + //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) + + hasAssigned := false + + switch fieldType.Kind() { + + case reflect.Complex64, reflect.Complex128: + if rawValueType.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } + case reflect.Slice, reflect.Array: + switch rawValueType.Kind() { + case reflect.Slice, reflect.Array: + switch rawValueType.Elem().Kind() { + case reflect.Uint8: + if fieldType.Elem().Kind() == reflect.Uint8 { + hasAssigned = true + fieldValue.Set(vv) + } + } + } + case reflect.String: + if rawValueType.Kind() == reflect.String { + hasAssigned = true + fieldValue.SetString(vv.String()) + } + case reflect.Bool: + if rawValueType.Kind() == reflect.Bool { + hasAssigned = true + fieldValue.SetBool(vv.Bool()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch rawValueType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetInt(vv.Int()) + } + case reflect.Float32, reflect.Float64: + switch rawValueType.Kind() { + case reflect.Float32, reflect.Float64: + hasAssigned = true + fieldValue.SetFloat(vv.Float()) + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + switch rawValueType.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + hasAssigned = true + fieldValue.SetUint(vv.Uint()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetUint(uint64(vv.Int())) + } + case reflect.Struct: + if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + fieldValue.Set(vv) + } + } else if session.Statement.UseCascade { + table := session.Engine.autoMapType(fieldValue.Type()) + if table != nil { + var x int64 + if rawValueType.Kind() == reflect.Int64 { + x = vv.Int() + } + if x != 0 { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(x).Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v := structInter.Elem().Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } + } + } else { + session.Engine.LogError("unsupported struct type in Scan: ", fieldValue.Type().String()) + } + } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + //typeStr := fieldType.String() + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case reflect.TypeOf(&c_EMPTY_STRING): + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_BOOL_DEFAULT): + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_TIME_DEFAULT): + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + var x time.Time = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_FLOAT64_DEFAULT): + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint64 = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_FLOAT32_DEFAULT): + if rawValueType.Kind() == reflect.Float64 { + var x float32 = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT32_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int32 = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT8_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int8 = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT16_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x int16 = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT32_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint32 = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT8_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint8 = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT16_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint16 = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_COMPLEX64_DEFAULT): + var x complex64 + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.Engine.LogError(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + case reflect.TypeOf(&c_COMPLEX128_DEFAULT): + var x complex128 + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.Engine.LogError(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + } // switch fieldType + // default: + // session.Engine.LogError("unsupported type in Scan: ", reflect.TypeOf(v).String()) + } // switch fieldType.Kind() + + // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value + if !hasAssigned { + data, err := value2Bytes(&rawValue) + if err == nil { + session.bytes2Value(table.Columns[key], fieldValue, data) + } else { + session.Engine.LogError(err.Error()) + } + } + } + } + return nil + +} + +func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { + for _, filter := range session.Engine.Filters { + *sqlStr = filter.Do(*sqlStr, session) + } + + session.Engine.LogSQL(*sqlStr) session.Engine.LogSQL(paramStr) +} + +func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + session.queryPreprocess(&sqlStr, paramStr...) if session.IsAutoCommit { - return query(session.Db, sql, paramStr...) + return query(session.Db, sqlStr, paramStr...) } - return txQuery(session.Tx, sql, paramStr...) + return txQuery(session.Tx, sqlStr, paramStr...) } -func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - rows, err := tx.Query(sql, params...) +func txQuery(tx *sql.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + rows, err := tx.Query(sqlStr, params...) if err != nil { return nil, err } @@ -1338,8 +1673,8 @@ func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[ return rows2maps(rows) } -func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - s, err := db.Prepare(sql) +func query(db *sql.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + s, err := db.Prepare(sqlStr) if err != nil { return nil, err } @@ -1350,12 +1685,11 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st } defer rows.Close() //fmt.Println(rows) - return rows2maps(rows) } // Exec a raw sql and return records as []map[string][]byte -func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) Query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { err = session.newDb() if err != nil { return nil, err @@ -1365,7 +1699,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice defer session.Close() } - return session.query(sql, paramStr...) + return session.query(sqlStr, paramStr...) } // insert one or more beans @@ -1639,7 +1973,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1651,7 +1985,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x := reflect.New(fieldType) err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1662,7 +1996,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x := reflect.New(fieldType) err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1676,7 +2010,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data d := string(data) v, err := strconv.ParseBool(d) if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) + return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(v)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -1704,19 +2038,19 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 10, 64) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.SetInt(x) case reflect.Float32, reflect.Float64: x, err := strconv.ParseFloat(string(data), 64) if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) + return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.SetFloat(x) case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.SetUint(x) //Currently only support Time type @@ -1733,7 +2067,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if table != nil { x, err := strconv.ParseInt(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } if x != 0 { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch @@ -1754,7 +2088,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } } } else { - return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) + return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } case reflect.Ptr: @@ -1770,7 +2104,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data d := string(data) v, err := strconv.ParseBool(d) if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) + return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&v)) // case "*complex64": @@ -1778,7 +2112,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x complex64 err := json.Unmarshal(data, &x) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(reflect.ValueOf(&x)) @@ -1787,7 +2121,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x complex128 err := json.Unmarshal(data, &x) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(reflect.ValueOf(&x)) @@ -1795,7 +2129,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data case reflect.TypeOf(&c_FLOAT64_DEFAULT): x, err := strconv.ParseFloat(string(data), 64) if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) + return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*float32": @@ -1803,7 +2137,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x float32 x1, err := strconv.ParseFloat(string(data), 32) if err != nil { - return errors.New("arg " + key + " as float32: " + err.Error()) + return fmt.Errorf("arg %v as float32: %s", key, err.Error()) } x = float32(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1820,7 +2154,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*uint": @@ -1828,7 +2162,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1837,7 +2171,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint32(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1846,7 +2180,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint8(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1855,7 +2189,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint16(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1881,7 +2215,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 10, 64) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int": @@ -1910,7 +2244,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int32": @@ -1939,7 +2273,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int32(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int8": @@ -1968,7 +2302,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int8(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int16": @@ -1997,14 +2331,14 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int16(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } return nil @@ -2144,7 +2478,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces = colPlaces[0 : len(colPlaces)-2] - sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", + sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", session.Engine.QuoteStr(), session.Statement.TableName(), session.Engine.QuoteStr(), @@ -2184,8 +2518,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. + if session.Engine.DriverName != POSTGRES || table.AutoIncrement == "" { - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } else { @@ -2236,8 +2571,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } else { //assert table.AutoIncrement != "" - sql = sql + " RETURNING " + session.Engine.Quote(table.AutoIncrement) - res, err := session.query(sql, args...) + sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) + res, err := session.query(sqlStr, args...) + if err != nil { return 0, err } else { @@ -2305,11 +2641,11 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (statement *Statement) convertUpdateSql(sql string) (string, string) { +func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) { if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { return "", "" } - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", @@ -2361,12 +2697,12 @@ func (session *Session) cacheInsert(tables ...string) error { return nil } -func (session *Session) cacheUpdate(sql string, args ...interface{}) error { +func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } - oldhead, newsql := session.Statement.convertUpdateSql(sql) + oldhead, newsql := session.Statement.convertUpdateSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2377,7 +2713,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { var nStart int if len(args) > 0 { - if strings.Index(sql, "?") > -1 { + if strings.Index(sqlStr, "?") > -1 { nStart = strings.Count(oldhead, "?") } else { // only for pq, TODO: if any other databse? @@ -2418,7 +2754,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } @@ -2557,7 +2893,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - var sql, inSql string + var sqlStr, inSql string var inArgs []interface{} if table.Version != "" && session.Statement.checkVersion { if condition != "" { @@ -2575,7 +2911,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", @@ -2595,7 +2931,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), condition) @@ -2605,13 +2941,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, inArgs...) args = append(args, condiArgs...) - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } if table.Cacher != nil && session.Statement.UseCache { - //session.cacheUpdate(sql, args...) + //session.cacheUpdate(sqlStr, args...) table.Cacher.ClearIds(session.Statement.TableName()) table.Cacher.ClearBeans(session.Statement.TableName()) } @@ -2648,16 +2984,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return res.RowsAffected() } -func (session *Session) cacheDelete(sql string, args ...interface{}) error { +func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2749,16 +3085,16 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, ErrNeedDeletedCond } - sql := fmt.Sprintf("DELETE FROM %v WHERE %v", + sqlStr := fmt.Sprintf("DELETE FROM %v WHERE %v", session.Engine.Quote(session.Statement.TableName()), condition) args = append(session.Statement.Params, args...) if table.Cacher != nil && session.Statement.UseCache { - session.cacheDelete(sql, args...) + session.cacheDelete(sqlStr, args...) } - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } diff --git a/table.go b/table.go index 84ba6e53..a1ae7cdc 100644 --- a/table.go +++ b/table.go @@ -163,6 +163,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { if t == reflect.TypeOf(c_TIME_DEFAULT) { st = SQLType{DateTime, 0, 0} } else { + // TODO need to handle association struct st = SQLType{Text, 0, 0} } case reflect.Ptr: