1. correct use of 'sql' string clash with sql pacakage. 2. checked type conversion.

This commit is contained in:
Nash Tsai 2013-12-28 02:42:50 +08:00
parent 7bbabe72f0
commit 814036e258
1 changed files with 310 additions and 160 deletions

View File

@ -186,8 +186,8 @@ func (session *Session) Desc(colNames ...string) *Session {
session.Statement.OrderStr += ", " session.Statement.OrderStr += ", "
} }
newColNames := col2NewCols(colNames...) newColNames := col2NewCols(colNames...)
sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) sqlStr := strings.Join(newColNames, session.Engine.Quote(" DESC, "))
session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " DESC"
return session return session
} }
@ -197,8 +197,8 @@ func (session *Session) Asc(colNames ...string) *Session {
session.Statement.OrderStr += ", " session.Statement.OrderStr += ", "
} }
newColNames := col2NewCols(colNames...) newColNames := col2NewCols(colNames...)
sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) sqlStr := strings.Join(newColNames, session.Engine.Quote(" ASC, "))
session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " ASC"
return session return session
} }
@ -392,8 +392,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
} }
//Execute sql //Execute sql
func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(sql) rs, err := session.Db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -406,22 +406,22 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result,
return res, nil return res, nil
} }
func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session) sqlStr = filter.Do(sqlStr, session)
} }
session.Engine.LogSQL(sql) session.Engine.LogSQL(sqlStr)
session.Engine.LogSQL(args) session.Engine.LogSQL(args)
if session.IsAutoCommit { if session.IsAutoCommit {
return session.innerExec(sql, args...) return session.innerExec(sqlStr, args...)
} }
return session.Tx.Exec(sql, args...) return session.Tx.Exec(sqlStr, args...)
} }
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) {
err := session.newDb() err := session.newDb()
if err != nil { if err != nil {
return nil, err return nil, err
@ -431,7 +431,7 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error
defer session.Close() defer session.Close()
} }
return session.exec(sql, args...) return session.exec(sqlStr, args...)
} }
// this function create a table according a bean // this function create a table according a bean
@ -464,8 +464,8 @@ func (session *Session) CreateIndexes(bean interface{}) error {
} }
sqls := session.Statement.genIndexSQL() sqls := session.Statement.genIndexSQL()
for _, sql := range sqls { for _, sqlStr := range sqls {
_, err = session.exec(sql) _, err = session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
@ -487,8 +487,8 @@ func (session *Session) CreateUniques(bean interface{}) error {
} }
sqls := session.Statement.genUniqueSQL() sqls := session.Statement.genUniqueSQL()
for _, sql := range sqls { for _, sqlStr := range sqls {
_, err = session.exec(sql) _, err = session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
@ -497,9 +497,9 @@ func (session *Session) CreateUniques(bean interface{}) error {
} }
func (session *Session) createOneTable() error { func (session *Session) createOneTable() error {
sql := session.Statement.genCreateTableSQL() sqlStr := session.Statement.genCreateTableSQL()
session.Engine.LogDebug("create table sql: [", sql, "]") session.Engine.LogDebug("create table sql: [", sqlStr, "]")
_, err := session.exec(sql) _, err := session.exec(sqlStr)
return err return err
} }
@ -536,8 +536,8 @@ func (session *Session) DropIndexes(bean interface{}) error {
} }
sqls := session.Statement.genDelIndexSQL() sqls := session.Statement.genDelIndexSQL()
for _, sql := range sqls { for _, sqlStr := range sqls {
_, err = session.exec(sql) _, err = session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
@ -567,16 +567,16 @@ func (session *Session) DropTable(bean interface{}) error {
return errors.New("Unsupported type") return errors.New("Unsupported type")
} }
sql := session.Statement.genDropSQL() sqlStr := session.Statement.genDropSQL()
_, err = session.exec(sql) _, err = session.exec(sqlStr)
return err return err
} }
func (statement *Statement) convertIdSql(sql string) string { func (statement *Statement) convertIdSql(sqlStr string) string {
if statement.RefTable != nil { if statement.RefTable != nil {
col := statement.RefTable.PKColumn() col := statement.RefTable.PKColumn()
if col != nil { if col != nil {
sqls := splitNNoCase(sql, "from", 2) sqls := splitNNoCase(sqlStr, "from", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
return "" return ""
} }
@ -588,14 +588,14 @@ func (statement *Statement) convertIdSql(sql string) string {
return "" return ""
} }
func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" {
return false, ErrCacheFailed return false, ErrCacheFailed
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session) sqlStr = filter.Do(sqlStr, session)
} }
newsql := session.Statement.convertIdSql(sql) newsql := session.Statement.convertIdSql(sqlStr)
if newsql == "" { if newsql == "" {
return false, ErrCacheFailed return false, ErrCacheFailed
} }
@ -667,19 +667,19 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface
return false, nil return false, nil
} }
func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) {
if session.Statement.RefTable == nil || if session.Statement.RefTable == nil ||
session.Statement.RefTable.PrimaryKey == "" || session.Statement.RefTable.PrimaryKey == "" ||
indexNoCase(sql, "having") != -1 || indexNoCase(sqlStr, "having") != -1 ||
indexNoCase(sql, "group by") != -1 { indexNoCase(sqlStr, "group by") != -1 {
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session) sqlStr = filter.Do(sqlStr, session)
} }
newsql := session.Statement.convertIdSql(sql) newsql := session.Statement.convertIdSql(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
@ -867,41 +867,67 @@ func (session *Session) Get(bean interface{}) (bool, error) {
} }
session.Statement.Limit(1) session.Statement.Limit(1)
var sql string var sqlStr string
var args []interface{} var args []interface{}
session.Statement.RefTable = session.Engine.autoMap(bean) session.Statement.RefTable = session.Engine.autoMap(bean)
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
sql, args = session.Statement.genGetSql(bean) sqlStr, args = session.Statement.genGetSql(bean)
} else { } else {
sql = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache {
has, err := session.cacheGet(bean, sql, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return has, err return has, err
} }
} }
resultsSlice, err := session.query(sql, args...) var rawRows *sql.Rows
session.queryPreprocess(sqlStr, args...)
if session.IsAutoCommit {
stmt, err := session.Db.Prepare(sqlStr)
if err != nil {
return false, err
}
defer stmt.Close()
rawRows, err = stmt.Query(args...)
} else {
rawRows, err = session.Tx.Query(sqlStr, args...)
}
if err != nil { if err != nil {
return false, err return false, err
} }
if len(resultsSlice) < 1 { defer rawRows.Close()
if rawRows.Next() {
if fields, err := rawRows.Columns(); err == nil {
err = session.row2Bean(rawRows, fields, len(fields), bean)
}
return true, err
} else {
return false, nil return false, nil
} }
err = session.scanMapIntoStruct(bean, resultsSlice[0]) // resultsSlice, err := session.query(sqlStr, args...)
if err != nil { // if err != nil {
return true, err // return false, err
} // }
if len(resultsSlice) == 1 { // if len(resultsSlice) < 1 {
return true, nil // return false, nil
} else { // }
return true, errors.New("More than one record")
} // err = session.scanMapIntoStruct(bean, resultsSlice[0])
// if err != nil {
// return true, err
// }
// if len(resultsSlice) == 1 {
// return true, nil
// } else {
// return true, errors.New("More than one record")
// }
} }
// Count counts the records. bean's non-empty fields // Count counts the records. bean's non-empty fields
@ -917,16 +943,16 @@ func (session *Session) Count(bean interface{}) (int64, error) {
defer session.Close() defer session.Close()
} }
var sql string var sqlStr string
var args []interface{} var args []interface{}
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
sql, args = session.Statement.genCountSql(bean) sqlStr, args = session.Statement.genCountSql(bean)
} else { } else {
sql = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
} }
resultsSlice, err := session.query(sql, args...) resultsSlice, err := session.query(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -987,7 +1013,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
session.Statement.BeanArgs = args session.Statement.BeanArgs = args
} }
var sql string var sqlStr string
var args []interface{} var args []interface{}
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
var columnStr string = session.Statement.ColumnStr var columnStr string = session.Statement.ColumnStr
@ -997,46 +1023,94 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
session.Statement.attachInSql() session.Statement.attachInSql()
sql = session.Statement.genSelectSql(columnStr) sqlStr = session.Statement.genSelectSql(columnStr)
args = append(session.Statement.Params, session.Statement.BeanArgs...) args = append(session.Statement.Params, session.Statement.BeanArgs...)
} else { } else {
sql = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if table.Cacher != nil && if table.Cacher != nil &&
session.Statement.UseCache && session.Statement.UseCache &&
!session.Statement.IsDistinct { !session.Statement.IsDistinct {
err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return err return err
} }
session.Engine.LogWarn("Cache Find Failed") session.Engine.LogWarn("Cache Find Failed")
} }
resultsSlice, err := session.query(sql, args...) if sliceValue.Kind() != reflect.Map {
if err != nil { var rawRows *sql.Rows
return err
}
for i, results := range resultsSlice { session.queryPreprocess(sqlStr, args...)
var newValue reflect.Value // err = session.queryRows(&stmt, &rawRows, sqlStr, args...)
if sliceElementType.Kind() == reflect.Ptr { // if err != nil {
newValue = reflect.New(sliceElementType.Elem()) // return err
// }
// if stmt != nil {
// defer stmt.Close()
// }
// defer rawRows.Close()
if session.IsAutoCommit {
stmt, err := session.Db.Prepare(sqlStr)
if err != nil {
return err
}
defer stmt.Close()
rawRows, err = stmt.Query(args...)
} else { } else {
newValue = reflect.New(sliceElementType) rawRows, err = session.Tx.Query(sqlStr, args...)
} }
err := session.scanMapIntoStruct(newValue.Interface(), results)
if err != nil { if err != nil {
return err return err
} }
if sliceValue.Kind() == reflect.Slice { defer rawRows.Close()
fields, err := rawRows.Columns()
if err != nil {
return err
}
fieldsCount := len(fields)
for rawRows.Next() {
var newValue reflect.Value
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) newValue = reflect.New(sliceElementType.Elem())
} else { } else {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) newValue = reflect.New(sliceElementType)
}
err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface())
if err != nil {
return err
}
if sliceValue.Kind() == reflect.Slice {
if sliceElementType.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface())))
} else {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
}
}
}
} else {
resultsSlice, err := session.query(sqlStr, args...)
if err != nil {
return err
}
for i, results := range resultsSlice {
var newValue reflect.Value
if sliceElementType.Kind() == reflect.Ptr {
newValue = reflect.New(sliceElementType.Elem())
} else {
newValue = reflect.New(sliceElementType)
}
err := session.scanMapIntoStruct(newValue.Interface(), results)
if err != nil {
return err
} }
} else if sliceValue.Kind() == reflect.Map {
var key int64 var key int64
if table.PrimaryKey != "" { if table.PrimaryKey != "" {
x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64)
@ -1057,6 +1131,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return nil return nil
} }
func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error {
var err error
if session.IsAutoCommit {
*rawStmt, err = session.Db.Prepare(sqlStr)
if err != nil {
return err
}
*rawRows, err = (*rawStmt).Query(args...)
} else {
*rawRows, err = session.Tx.Query(sqlStr, args...)
}
return err
}
// Test if database is ok // Test if database is ok
func (session *Session) Ping() error { func (session *Session) Ping() error {
err := session.newDb() err := session.newDb()
@ -1080,8 +1168,8 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) {
if session.IsAutoClose { if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) sqlStr, args := session.Engine.dialect.ColumnCheckSql(tableName, colName)
results, err := session.query(sql, args...) results, err := session.query(sqlStr, args...)
return len(results) > 0, err return len(results) > 0, err
} }
@ -1094,8 +1182,8 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
if session.IsAutoClose { if session.IsAutoClose {
defer session.Close() defer session.Close()
} }
sql, args := session.Engine.dialect.TableCheckSql(tableName) sqlStr, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sql, args...) results, err := session.query(sqlStr, args...)
return len(results) > 0, err return len(results) > 0, err
} }
@ -1114,8 +1202,8 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo
} else { } else {
idx = indexName(tableName, idxName) idx = indexName(tableName, idxName)
} }
sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx)
results, err := session.query(sql, args...) results, err := session.query(sqlStr, args...)
return len(results) > 0, err return len(results) > 0, err
} }
@ -1149,8 +1237,8 @@ func (session *Session) addColumn(colName string) error {
} }
//fmt.Println(session.Statement.RefTable) //fmt.Println(session.Statement.RefTable)
col := session.Statement.RefTable.Columns[colName] col := session.Statement.RefTable.Columns[colName]
sql, args := session.Statement.genAddColumnStr(col) sqlStr, args := session.Statement.genAddColumnStr(col)
_, err = session.exec(sql, args...) _, err = session.exec(sqlStr, args...)
return err return err
} }
@ -1165,8 +1253,8 @@ func (session *Session) addIndex(tableName, idxName string) error {
} }
//fmt.Println(idxName) //fmt.Println(idxName)
cols := session.Statement.RefTable.Indexes[idxName].Cols cols := session.Statement.RefTable.Indexes[idxName].Cols
sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) sqlStr, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols)
_, err = session.exec(sql, args...) _, err = session.exec(sqlStr, args...)
return err return err
} }
@ -1181,8 +1269,8 @@ func (session *Session) addUnique(tableName, uqeName string) error {
} }
//fmt.Println(uqeName, session.Statement.RefTable.Uniques) //fmt.Println(uqeName, session.Statement.RefTable.Uniques)
cols := session.Statement.RefTable.Indexes[uqeName].Cols cols := session.Statement.RefTable.Indexes[uqeName].Cols
sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) sqlStr, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols)
_, err = session.exec(sql, args...) _, err = session.exec(sqlStr, args...)
return err return err
} }
@ -1200,8 +1288,8 @@ func (session *Session) dropAll() error {
for _, table := range session.Engine.Tables { for _, table := range session.Engine.Tables {
session.Statement.Init() session.Statement.Init()
session.Statement.RefTable = table session.Statement.RefTable = table
sql := session.Statement.genDropSQL() sqlStr := session.Statement.genDropSQL()
_, err := session.exec(sql) _, err := session.exec(sqlStr)
if err != nil { if err != nil {
return err return err
} }
@ -1306,7 +1394,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
continue continue
} }
aa := reflect.TypeOf(rawValue.Interface()) rawValueType := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface())
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
@ -1318,7 +1406,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
if aa.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
hasAssigned = true hasAssigned = true
x := reflect.New(fieldType) x := reflect.New(fieldType)
err := json.Unmarshal([]byte(vv.String()), x.Interface()) err := json.Unmarshal([]byte(vv.String()), x.Interface())
@ -1329,38 +1417,40 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch aa.Kind() { switch rawValueType.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch aa.Elem().Kind() { switch rawValueType.Elem().Kind() {
case reflect.Uint8: case reflect.Uint8:
hasAssigned = true if fieldType.Elem().Kind() == reflect.Uint8 {
fieldValue.Set(rawValue) hasAssigned = true
fieldValue.Set(vv)
}
} }
} }
case reflect.String: case reflect.String:
if aa.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
hasAssigned = true hasAssigned = true
fieldValue.SetString(vv.String()) fieldValue.SetString(vv.String())
} }
case reflect.Bool: case reflect.Bool:
if aa.Kind() == reflect.Bool { if rawValueType.Kind() == reflect.Bool {
hasAssigned = true hasAssigned = true
fieldValue.SetBool(vv.Bool()) fieldValue.SetBool(vv.Bool())
} }
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
switch aa.Kind() { switch rawValueType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
hasAssigned = true hasAssigned = true
fieldValue.SetInt(vv.Int()) fieldValue.SetInt(vv.Int())
} }
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
switch aa.Kind() { switch rawValueType.Kind() {
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
hasAssigned = true hasAssigned = true
fieldValue.SetFloat(vv.Float()) fieldValue.SetFloat(vv.Float())
} }
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
switch aa.Kind() { switch rawValueType.Kind() {
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
hasAssigned = true hasAssigned = true
fieldValue.SetUint(vv.Uint()) fieldValue.SetUint(vv.Uint())
@ -1368,7 +1458,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
//Currently only support Time type //Currently only support Time type
case reflect.Struct: case reflect.Struct:
if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { if fieldType == reflect.TypeOf(c_TIME_DEFAULT) {
if aa == reflect.TypeOf(c_TIME_DEFAULT) { if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) {
hasAssigned = true hasAssigned = true
fieldValue.Set(rawValue) fieldValue.Set(rawValue)
} }
@ -1407,46 +1497,95 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
//typeStr := fieldType.String() //typeStr := fieldType.String()
switch fieldType { switch fieldType {
// following types case matching ptr's native type, therefore assign ptr directly // following types case matching ptr's native type, therefore assign ptr directly
case reflect.TypeOf(&c_EMPTY_STRING), reflect.TypeOf(&c_BOOL_DEFAULT), reflect.TypeOf(&c_TIME_DEFAULT), case reflect.TypeOf(&c_EMPTY_STRING):
reflect.TypeOf(&c_FLOAT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT), reflect.TypeOf(&c_INT64_DEFAULT): if rawValueType.Kind() == reflect.String {
hasAssigned = true x := vv.String()
fieldValue.Set(reflect.ValueOf(&rawValue)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_BOOL_DEFAULT):
if rawValueType.Kind() == reflect.Bool {
x := vv.Bool()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_TIME_DEFAULT):
if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) {
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&rawValue))
}
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
if rawValueType.Kind() == reflect.Float64 {
x := vv.Float()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_UINT64_DEFAULT):
if rawValueType.Kind() == reflect.Int64 {
var x uint64 = uint64(vv.Int())
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_INT64_DEFAULT):
if rawValueType.Kind() == reflect.Int64 {
x := vv.Int()
hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_FLOAT32_DEFAULT): case reflect.TypeOf(&c_FLOAT32_DEFAULT):
var x float32 = float32(vv.Float()) if rawValueType.Kind() == reflect.Float64 {
hasAssigned = true var x float32 = float32(vv.Float())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_INT_DEFAULT): case reflect.TypeOf(&c_INT_DEFAULT):
var x int = int(vv.Int()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x int = int(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_INT32_DEFAULT): case reflect.TypeOf(&c_INT32_DEFAULT):
var x int32 = int32(vv.Int()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x int32 = int32(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_INT8_DEFAULT): case reflect.TypeOf(&c_INT8_DEFAULT):
var x int8 = int8(vv.Int()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x int8 = int8(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_INT16_DEFAULT): case reflect.TypeOf(&c_INT16_DEFAULT):
var x int16 = int16(vv.Int()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x int16 = int16(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_UINT_DEFAULT): case reflect.TypeOf(&c_UINT_DEFAULT):
var x uint = uint(vv.Uint()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x uint = uint(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_UINT32_DEFAULT): case reflect.TypeOf(&c_UINT32_DEFAULT):
var x uint32 = uint32(vv.Uint()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x uint32 = uint32(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_UINT8_DEFAULT): case reflect.TypeOf(&c_UINT8_DEFAULT):
var x uint8 = uint8(vv.Uint()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x uint8 = uint8(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_UINT16_DEFAULT): case reflect.TypeOf(&c_UINT16_DEFAULT):
var x uint16 = uint16(vv.Uint()) if rawValueType.Kind() == reflect.Int64 {
hasAssigned = true var x uint16 = uint16(vv.Int())
fieldValue.Set(reflect.ValueOf(&x)) hasAssigned = true
fieldValue.Set(reflect.ValueOf(&x))
}
case reflect.TypeOf(&c_COMPLEX64_DEFAULT): case reflect.TypeOf(&c_COMPLEX64_DEFAULT):
var x complex64 var x complex64
err := json.Unmarshal([]byte(vv.String()), &x) err := json.Unmarshal([]byte(vv.String()), &x)
@ -1485,22 +1624,33 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
} }
func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) queryPreprocess(sqlStr string, paramStr ...interface{}) {
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session) sqlStr = filter.Do(sqlStr, session)
} }
session.Engine.LogSQL(sql) session.Engine.LogSQL(sqlStr)
session.Engine.LogSQL(paramStr)
}
func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
// !nashtsai! TODO calling session.queryPreprocess with cause error
// session.queryPreprocess(sqlStr, paramStr...)
for _, filter := range session.Engine.Filters {
sqlStr = filter.Do(sqlStr, session)
}
session.Engine.LogSQL(sqlStr)
session.Engine.LogSQL(paramStr) session.Engine.LogSQL(paramStr)
if session.IsAutoCommit { if session.IsAutoCommit {
return query(session.Db, sql, paramStr...) return query(session.Db, sqlStr, paramStr...)
} }
return txQuery(session.Tx, sql, paramStr...) return txQuery(session.Tx, sqlStr, paramStr...)
} }
func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { func txQuery(tx *sql.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
rows, err := tx.Query(sql, params...) rows, err := tx.Query(sqlStr, params...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1509,8 +1659,8 @@ func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[
return rows2maps(rows) return rows2maps(rows)
} }
func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { func query(db *sql.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
s, err := db.Prepare(sql) s, err := db.Prepare(sqlStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1525,7 +1675,7 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st
} }
// Exec a raw sql and return records as []map[string][]byte // Exec a raw sql and return records as []map[string][]byte
func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) Query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
err = session.newDb() err = session.newDb()
if err != nil { if err != nil {
return nil, err return nil, err
@ -1535,7 +1685,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice
defer session.Close() defer session.Close()
} }
return session.query(sql, paramStr...) return session.query(sqlStr, paramStr...)
} }
// insert one or more beans // insert one or more beans
@ -2310,7 +2460,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces := strings.Repeat("?, ", len(colNames))
colPlaces = colPlaces[0 : len(colPlaces)-2] colPlaces = colPlaces[0 : len(colPlaces)-2]
sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
@ -2351,7 +2501,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" {
res, err := session.exec(sql, args...) res, err := session.exec(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} else { } else {
@ -2395,8 +2545,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return res.RowsAffected() return res.RowsAffected()
} else { } else {
sql = sql + " RETURNING (id)" sqlStr = sqlStr + " RETURNING (id)"
res, err := session.query(sql, args...) res, err := session.query(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} else { } else {
@ -2458,11 +2608,11 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return session.innerInsert(bean) return session.innerInsert(bean)
} }
func (statement *Statement) convertUpdateSql(sql string) (string, string) { func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) {
if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" {
return "", "" return "", ""
} }
sqls := splitNNoCase(sql, "where", 2) sqls := splitNNoCase(sqlStr, "where", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
if len(sqls) == 1 { if len(sqls) == 1 {
return sqls[0], fmt.Sprintf("SELECT %v FROM %v", return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
@ -2505,12 +2655,12 @@ func (session *Session) cacheInsert(tables ...string) error {
return nil return nil
} }
func (session *Session) cacheUpdate(sql string, args ...interface{}) error { func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" {
return ErrCacheFailed return ErrCacheFailed
} }
oldhead, newsql := session.Statement.convertUpdateSql(sql) oldhead, newsql := session.Statement.convertUpdateSql(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
@ -2521,7 +2671,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error {
var nStart int var nStart int
if len(args) > 0 { if len(args) > 0 {
if strings.Index(sql, "?") > -1 { if strings.Index(sqlStr, "?") > -1 {
nStart = strings.Count(oldhead, "?") nStart = strings.Count(oldhead, "?")
} else { } else {
// only for pq, TODO: if any other databse? // only for pq, TODO: if any other databse?
@ -2562,7 +2712,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error {
for _, id := range ids { for _, id := range ids {
if bean := cacher.GetBean(tableName, id); bean != nil { if bean := cacher.GetBean(tableName, id); bean != nil {
sqls := splitNNoCase(sql, "where", 2) sqls := splitNNoCase(sqlStr, "where", 2)
if len(sqls) == 0 || len(sqls) > 2 { if len(sqls) == 0 || len(sqls) > 2 {
return ErrCacheFailed return ErrCacheFailed
} }
@ -2701,7 +2851,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
var sql, inSql string var sqlStr, inSql string
var inArgs []interface{} var inArgs []interface{}
if table.Version != "" && session.Statement.checkVersion { if table.Version != "" && session.Statement.checkVersion {
if condition != "" { if condition != "" {
@ -2719,7 +2869,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v",
session.Engine.Quote(session.Statement.TableName()), session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1",
@ -2739,7 +2889,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sql = fmt.Sprintf("UPDATE %v SET %v %v", sqlStr = fmt.Sprintf("UPDATE %v SET %v %v",
session.Engine.Quote(session.Statement.TableName()), session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
condition) condition)
@ -2749,13 +2899,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
args = append(args, inArgs...) args = append(args, inArgs...)
args = append(args, condiArgs...) args = append(args, condiArgs...)
res, err := session.exec(sql, args...) res, err := session.exec(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if table.Cacher != nil && session.Statement.UseCache { if table.Cacher != nil && session.Statement.UseCache {
//session.cacheUpdate(sql, args...) //session.cacheUpdate(sqlStr, args...)
table.Cacher.ClearIds(session.Statement.TableName()) table.Cacher.ClearIds(session.Statement.TableName())
table.Cacher.ClearBeans(session.Statement.TableName()) table.Cacher.ClearBeans(session.Statement.TableName())
} }
@ -2792,16 +2942,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return res.RowsAffected() return res.RowsAffected()
} }
func (session *Session) cacheDelete(sql string, args ...interface{}) error { func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" {
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session) sqlStr = filter.Do(sqlStr, session)
} }
newsql := session.Statement.convertIdSql(sql) newsql := session.Statement.convertIdSql(sqlStr)
if newsql == "" { if newsql == "" {
return ErrCacheFailed return ErrCacheFailed
} }
@ -2893,16 +3043,16 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
return 0, ErrNeedDeletedCond return 0, ErrNeedDeletedCond
} }
sql := fmt.Sprintf("DELETE FROM %v WHERE %v", sqlStr := fmt.Sprintf("DELETE FROM %v WHERE %v",
session.Engine.Quote(session.Statement.TableName()), condition) session.Engine.Quote(session.Statement.TableName()), condition)
args = append(session.Statement.Params, args...) args = append(session.Statement.Params, args...)
if table.Cacher != nil && session.Statement.UseCache { if table.Cacher != nil && session.Statement.UseCache {
session.cacheDelete(sql, args...) session.cacheDelete(sqlStr, args...)
} }
res, err := session.exec(sql, args...) res, err := session.exec(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }