diff --git a/session.go b/session.go index 9d86b1de..11164215 100644 --- a/session.go +++ b/session.go @@ -186,8 +186,8 @@ func (session *Session) Desc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " DESC" return session } @@ -197,8 +197,8 @@ func (session *Session) Asc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " ASC" return session } @@ -392,8 +392,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b } //Execute sql -func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(sql) +func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(sqlStr) if err != nil { return nil, err } @@ -406,22 +406,22 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, return res, nil } -func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - session.Engine.LogSQL(sql) + session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(args) if session.IsAutoCommit { - return session.innerExec(sql, args...) + return session.innerExec(sqlStr, args...) } - return session.Tx.Exec(sql, args...) + return session.Tx.Exec(sqlStr, args...) } // Exec raw sql -func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { err := session.newDb() if err != nil { return nil, err @@ -431,7 +431,7 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error defer session.Close() } - return session.exec(sql, args...) + return session.exec(sqlStr, args...) } // this function create a table according a bean @@ -464,8 +464,8 @@ func (session *Session) CreateIndexes(bean interface{}) error { } sqls := session.Statement.genIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -487,8 +487,8 @@ func (session *Session) CreateUniques(bean interface{}) error { } sqls := session.Statement.genUniqueSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -497,9 +497,9 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createOneTable() error { - sql := session.Statement.genCreateTableSQL() - session.Engine.LogDebug("create table sql: [", sql, "]") - _, err := session.exec(sql) + sqlStr := session.Statement.genCreateTableSQL() + session.Engine.LogDebug("create table sql: [", sqlStr, "]") + _, err := session.exec(sqlStr) return err } @@ -536,8 +536,8 @@ func (session *Session) DropIndexes(bean interface{}) error { } sqls := session.Statement.genDelIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -567,16 +567,16 @@ func (session *Session) DropTable(bean interface{}) error { return errors.New("Unsupported type") } - sql := session.Statement.genDropSQL() - _, err = session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err = session.exec(sqlStr) return err } -func (statement *Statement) convertIdSql(sql string) string { +func (statement *Statement) convertIdSql(sqlStr string) string { if statement.RefTable != nil { col := statement.RefTable.PKColumn() if col != nil { - sqls := splitNNoCase(sql, "from", 2) + sqls := splitNNoCase(sqlStr, "from", 2) if len(sqls) != 2 { return "" } @@ -588,14 +588,14 @@ func (statement *Statement) convertIdSql(sql string) string { return "" } -func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { +func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return false, ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return false, ErrCacheFailed } @@ -667,19 +667,19 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface return false, nil } -func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { +func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" || - indexNoCase(sql, "having") != -1 || - indexNoCase(sql, "group by") != -1 { + indexNoCase(sqlStr, "having") != -1 || + indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -867,41 +867,67 @@ func (session *Session) Get(bean interface{}) (bool, error) { } session.Statement.Limit(1) - var sql string + var sqlStr string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) + sqlStr, args = session.Statement.genGetSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { - has, err := session.cacheGet(bean, sql, args...) + has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err } } - resultsSlice, err := session.query(sql, args...) + var rawRows *sql.Rows + session.queryPreprocess(sqlStr, args...) + if session.IsAutoCommit { + stmt, err := session.Db.Prepare(sqlStr) + if err != nil { + return false, err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) + } else { + rawRows, err = session.Tx.Query(sqlStr, args...) + } if err != nil { return false, err } - if len(resultsSlice) < 1 { + defer rawRows.Close() + + if rawRows.Next() { + if fields, err := rawRows.Columns(); err == nil { + err = session.row2Bean(rawRows, fields, len(fields), bean) + } + return true, err + } else { return false, nil } - err = session.scanMapIntoStruct(bean, resultsSlice[0]) - if err != nil { - return true, err - } - if len(resultsSlice) == 1 { - return true, nil - } else { - return true, errors.New("More than one record") - } + // resultsSlice, err := session.query(sqlStr, args...) + // if err != nil { + // return false, err + // } + // if len(resultsSlice) < 1 { + // return false, nil + // } + + // err = session.scanMapIntoStruct(bean, resultsSlice[0]) + // if err != nil { + // return true, err + // } + // if len(resultsSlice) == 1 { + // return true, nil + // } else { + // return true, errors.New("More than one record") + // } } // Count counts the records. bean's non-empty fields @@ -917,16 +943,16 @@ func (session *Session) Count(bean interface{}) (int64, error) { defer session.Close() } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - sql, args = session.Statement.genCountSql(bean) + sqlStr, args = session.Statement.genCountSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } - resultsSlice, err := session.query(sql, args...) + resultsSlice, err := session.query(sqlStr, args...) if err != nil { return 0, err } @@ -987,7 +1013,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.BeanArgs = args } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr @@ -997,46 +1023,94 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.attachInSql() - sql = session.Statement.genSelectSql(columnStr) + sqlStr = session.Statement.genSelectSql(columnStr) args = append(session.Statement.Params, session.Statement.BeanArgs...) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if table.Cacher != nil && session.Statement.UseCache && !session.Statement.IsDistinct { - err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) + err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } session.Engine.LogWarn("Cache Find Failed") } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return err - } + if sliceValue.Kind() != reflect.Map { + var rawRows *sql.Rows - for i, results := range resultsSlice { - var newValue reflect.Value - if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) + session.queryPreprocess(sqlStr, args...) + // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) + // if err != nil { + // return err + // } + // if stmt != nil { + // defer stmt.Close() + // } + // defer rawRows.Close() + + if session.IsAutoCommit { + stmt, err := session.Db.Prepare(sqlStr) + if err != nil { + return err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) } else { - newValue = reflect.New(sliceElementType) + rawRows, err = session.Tx.Query(sqlStr, args...) } - err := session.scanMapIntoStruct(newValue.Interface(), results) if err != nil { return err } - if sliceValue.Kind() == reflect.Slice { + defer rawRows.Close() + + fields, err := rawRows.Columns() + if err != nil { + return err + } + + fieldsCount := len(fields) + + for rawRows.Next() { + var newValue reflect.Value if sliceElementType.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + newValue = reflect.New(sliceElementType.Elem()) } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + newValue = reflect.New(sliceElementType) + } + err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface()) + if err != nil { + return err + } + if sliceValue.Kind() == reflect.Slice { + if sliceElementType.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + } + } + } else { + resultsSlice, err := session.query(sqlStr, args...) + if err != nil { + return err + } + + for i, results := range resultsSlice { + var newValue reflect.Value + if sliceElementType.Kind() == reflect.Ptr { + newValue = reflect.New(sliceElementType.Elem()) + } else { + newValue = reflect.New(sliceElementType) + } + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err } - } else if sliceValue.Kind() == reflect.Map { var key int64 if table.PrimaryKey != "" { x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) @@ -1057,6 +1131,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } +func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { + var err error + if session.IsAutoCommit { + *rawStmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return err + } + *rawRows, err = (*rawStmt).Query(args...) + } else { + *rawRows, err = session.Tx.Query(sqlStr, args...) + } + return err +} + // Test if database is ok func (session *Session) Ping() error { err := session.newDb() @@ -1080,8 +1168,8 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1094,8 +1182,8 @@ func (session *Session) isTableExist(tableName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1114,8 +1202,8 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo } else { idx = indexName(tableName, idxName) } - sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1149,8 +1237,8 @@ func (session *Session) addColumn(colName string) error { } //fmt.Println(session.Statement.RefTable) col := session.Statement.RefTable.Columns[colName] - sql, args := session.Statement.genAddColumnStr(col) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddColumnStr(col) + _, err = session.exec(sqlStr, args...) return err } @@ -1165,8 +1253,8 @@ func (session *Session) addIndex(tableName, idxName string) error { } //fmt.Println(idxName) cols := session.Statement.RefTable.Indexes[idxName].Cols - sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1181,8 +1269,8 @@ func (session *Session) addUnique(tableName, uqeName string) error { } //fmt.Println(uqeName, session.Statement.RefTable.Uniques) cols := session.Statement.RefTable.Indexes[uqeName].Cols - sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1200,8 +1288,8 @@ func (session *Session) dropAll() error { for _, table := range session.Engine.Tables { session.Statement.Init() session.Statement.RefTable = table - sql := session.Statement.genDropSQL() - _, err := session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err := session.exec(sqlStr) if err != nil { return err } @@ -1306,7 +1394,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in continue } - aa := reflect.TypeOf(rawValue.Interface()) + rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) fieldType := fieldValue.Type() @@ -1318,7 +1406,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in switch fieldType.Kind() { case reflect.Complex64, reflect.Complex128: - if aa.Kind() == reflect.String { + if rawValueType.Kind() == reflect.String { hasAssigned = true x := reflect.New(fieldType) err := json.Unmarshal([]byte(vv.String()), x.Interface()) @@ -1329,38 +1417,40 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in fieldValue.Set(x.Elem()) } case reflect.Slice, reflect.Array: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Slice, reflect.Array: - switch aa.Elem().Kind() { + switch rawValueType.Elem().Kind() { case reflect.Uint8: - hasAssigned = true - fieldValue.Set(rawValue) + if fieldType.Elem().Kind() == reflect.Uint8 { + hasAssigned = true + fieldValue.Set(vv) + } } } case reflect.String: - if aa.Kind() == reflect.String { + if rawValueType.Kind() == reflect.String { hasAssigned = true fieldValue.SetString(vv.String()) } case reflect.Bool: - if aa.Kind() == reflect.Bool { + if rawValueType.Kind() == reflect.Bool { hasAssigned = true fieldValue.SetBool(vv.Bool()) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: hasAssigned = true fieldValue.SetInt(vv.Int()) } case reflect.Float32, reflect.Float64: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Float32, reflect.Float64: hasAssigned = true fieldValue.SetFloat(vv.Float()) } case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: hasAssigned = true fieldValue.SetUint(vv.Uint()) @@ -1368,7 +1458,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in //Currently only support Time type case reflect.Struct: if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { - if aa == reflect.TypeOf(c_TIME_DEFAULT) { + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { hasAssigned = true fieldValue.Set(rawValue) } @@ -1407,46 +1497,95 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in //typeStr := fieldType.String() switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly - case reflect.TypeOf(&c_EMPTY_STRING), reflect.TypeOf(&c_BOOL_DEFAULT), reflect.TypeOf(&c_TIME_DEFAULT), - reflect.TypeOf(&c_FLOAT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT), reflect.TypeOf(&c_INT64_DEFAULT): - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&rawValue)) + case reflect.TypeOf(&c_EMPTY_STRING): + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_BOOL_DEFAULT): + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_TIME_DEFAULT): + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&rawValue)) + } + case reflect.TypeOf(&c_FLOAT64_DEFAULT): + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint64 = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_FLOAT32_DEFAULT): - var x float32 = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Float64 { + var x float32 = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT_DEFAULT): - var x int = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT32_DEFAULT): - var x int32 = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int32 = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT8_DEFAULT): - var x int8 = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int8 = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT16_DEFAULT): - var x int16 = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int16 = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT_DEFAULT): - var x uint = uint(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT32_DEFAULT): - var x uint32 = uint32(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint32 = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT8_DEFAULT): - var x uint8 = uint8(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint8 = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT16_DEFAULT): - var x uint16 = uint16(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint16 = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_COMPLEX64_DEFAULT): var x complex64 err := json.Unmarshal([]byte(vv.String()), &x) @@ -1485,22 +1624,33 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in } -func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) queryPreprocess(sqlStr string, paramStr ...interface{}) { for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - session.Engine.LogSQL(sql) + session.Engine.LogSQL(sqlStr) + session.Engine.LogSQL(paramStr) +} + +func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + // !nashtsai! TODO calling session.queryPreprocess with cause error + // session.queryPreprocess(sqlStr, paramStr...) + for _, filter := range session.Engine.Filters { + sqlStr = filter.Do(sqlStr, session) + } + + session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(paramStr) if session.IsAutoCommit { - return query(session.Db, sql, paramStr...) + return query(session.Db, sqlStr, paramStr...) } - return txQuery(session.Tx, sql, paramStr...) + return txQuery(session.Tx, sqlStr, paramStr...) } -func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - rows, err := tx.Query(sql, params...) +func txQuery(tx *sql.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + rows, err := tx.Query(sqlStr, params...) if err != nil { return nil, err } @@ -1509,8 +1659,8 @@ func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[ return rows2maps(rows) } -func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - s, err := db.Prepare(sql) +func query(db *sql.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + s, err := db.Prepare(sqlStr) if err != nil { return nil, err } @@ -1525,7 +1675,7 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st } // Exec a raw sql and return records as []map[string][]byte -func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) Query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { err = session.newDb() if err != nil { return nil, err @@ -1535,7 +1685,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice defer session.Close() } - return session.query(sql, paramStr...) + return session.query(sqlStr, paramStr...) } // insert one or more beans @@ -2310,7 +2460,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces = colPlaces[0 : len(colPlaces)-2] - sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", + sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", session.Engine.QuoteStr(), session.Statement.TableName(), session.Engine.QuoteStr(), @@ -2351,7 +2501,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } else { @@ -2395,8 +2545,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } else { - sql = sql + " RETURNING (id)" - res, err := session.query(sql, args...) + sqlStr = sqlStr + " RETURNING (id)" + res, err := session.query(sqlStr, args...) if err != nil { return 0, err } else { @@ -2458,11 +2608,11 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (statement *Statement) convertUpdateSql(sql string) (string, string) { +func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) { if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { return "", "" } - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", @@ -2505,12 +2655,12 @@ func (session *Session) cacheInsert(tables ...string) error { return nil } -func (session *Session) cacheUpdate(sql string, args ...interface{}) error { +func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return ErrCacheFailed } - oldhead, newsql := session.Statement.convertUpdateSql(sql) + oldhead, newsql := session.Statement.convertUpdateSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2521,7 +2671,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { var nStart int if len(args) > 0 { - if strings.Index(sql, "?") > -1 { + if strings.Index(sqlStr, "?") > -1 { nStart = strings.Count(oldhead, "?") } else { // only for pq, TODO: if any other databse? @@ -2562,7 +2712,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } @@ -2701,7 +2851,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - var sql, inSql string + var sqlStr, inSql string var inArgs []interface{} if table.Version != "" && session.Statement.checkVersion { if condition != "" { @@ -2719,7 +2869,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", @@ -2739,7 +2889,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), condition) @@ -2749,13 +2899,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, inArgs...) args = append(args, condiArgs...) - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } if table.Cacher != nil && session.Statement.UseCache { - //session.cacheUpdate(sql, args...) + //session.cacheUpdate(sqlStr, args...) table.Cacher.ClearIds(session.Statement.TableName()) table.Cacher.ClearBeans(session.Statement.TableName()) } @@ -2792,16 +2942,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return res.RowsAffected() } -func (session *Session) cacheDelete(sql string, args ...interface{}) error { +func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2893,16 +3043,16 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, ErrNeedDeletedCond } - sql := fmt.Sprintf("DELETE FROM %v WHERE %v", + sqlStr := fmt.Sprintf("DELETE FROM %v WHERE %v", session.Engine.Quote(session.Statement.TableName()), condition) args = append(session.Statement.Params, args...) if table.Cacher != nil && session.Statement.UseCache { - session.cacheDelete(sql, args...) + session.cacheDelete(sqlStr, args...) } - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err }