From 7bbabe72f0f4bb318ad5c0022fce6bcfcfbab29e Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 27 Dec 2013 02:14:30 +0800 Subject: [PATCH 1/5] add Session.row2Bean --- helpers.go | 73 +++++++++++ rows.go | 26 ++-- session.go | 358 ++++++++++++++++++++++++++++++++++++++++------------- table.go | 1 + 4 files changed, 361 insertions(+), 97 deletions(-) diff --git a/helpers.go b/helpers.go index 307353c2..96f118f2 100644 --- a/helpers.go +++ b/helpers.go @@ -1,8 +1,12 @@ package xorm import ( + "database/sql" + "fmt" "reflect" + "strconv" "strings" + "time" ) func indexNoCase(s, sep string) int { @@ -61,3 +65,72 @@ func sliceEq(left, right []string) bool { return true } + +func value2Bytes(rawValue *reflect.Value) (data []byte, err error) { + + aa := reflect.TypeOf((*rawValue).Interface()) + vv := reflect.ValueOf((*rawValue).Interface()) + + var str string + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + data = []byte(str) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + data = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + data = []byte(str) + case reflect.String: + str = vv.String() + data = []byte(str) + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + data = rawValue.Interface().([]byte) + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + //时间类型 + case reflect.Struct: + if aa == reflect.TypeOf(c_TIME_DEFAULT) { + str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) + data = []byte(str) + } else { + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + data = []byte(str) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + data = []byte(str) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) + } + return +} + +func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2map(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} diff --git a/rows.go b/rows.go index 7aeb87f0..0ac6c956 100644 --- a/rows.go +++ b/rows.go @@ -9,12 +9,13 @@ import ( type Rows struct { NoTypeCheck bool - session *Session - stmt *sql.Stmt - rows *sql.Rows - fields []string - beanType reflect.Type - lastError error + session *Session + stmt *sql.Stmt + rows *sql.Rows + fields []string + fieldsCount int + beanType reflect.Type + lastError error } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -66,6 +67,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { defer rows.Close() return nil, err } + rows.fieldsCount = len(rows.fields) return rows, nil } @@ -97,11 +99,13 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - result, err := row2map(rows.rows, rows.fields) // !nashtsai! TODO remove row2map then scanMapIntoStruct conversation for better performance - if err == nil { - err = rows.session.scanMapIntoStruct(bean, result) - } - return err + return rows.session.row2Bean(rows.rows, rows.fields, rows.fieldsCount, bean) + + // result, err := row2map(rows.rows, rows.fields) // !nashtsai! TODO remove row2map then scanMapIntoStruct conversation for better performance + // if err == nil { + // err = rows.session.scanMapIntoStruct(bean, result) + // } + // return err } // // Columns returns the column names. Columns returns an error if the rows are closed, or if the rows are from QueryRow and there was a deferred error. diff --git a/session.go b/session.go index a15852ba..9d86b1de 100644 --- a/session.go +++ b/session.go @@ -1220,8 +1220,6 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err return nil, err } - // !nashtsai! TODO optimization for query performance, where current process has gone from - // sql driver converted type back to []bytes then to ORM's fields for ii, key := range fields { rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) @@ -1230,72 +1228,261 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err //fmt.Println("ignore ...", key, rawValue) continue } - aa := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - var str string - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - result[key] = []byte(str) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - result[key] = []byte(str) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - result[key] = []byte(str) - case reflect.String: - str = vv.String() - result[key] = []byte(str) - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - result[key] = rawValue.Interface().([]byte) - str = string(result[key]) - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - //时间类型 - case reflect.Struct: - if aa == reflect.TypeOf(c_TIME_DEFAULT) { - str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) - result[key] = []byte(str) - } else { - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - result[key] = []byte(str) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - result[key] = []byte(str) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) + + if data, err := value2Bytes(&rawValue); err == nil { + result[key] = data + } else { + return nil, err // !nashtsai! REVIEW, should return err or just error log? } } return result, nil } -func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err +func (session *Session) getField(dataStruct *reflect.Value, key string, table *Table) *reflect.Value { + + key = strings.ToLower(key) + if _, ok := table.Columns[key]; !ok { + session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) + return nil } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err + col := table.Columns[key] + fieldName := col.FieldName + fieldPath := strings.Split(fieldName, ".") + var fieldValue reflect.Value + if len(fieldPath) > 2 { + session.Engine.LogError("Unsupported mutliderive", fieldName) + return nil + } else if len(fieldPath) == 2 { + parentField := dataStruct.FieldByName(fieldPath[0]) + if parentField.IsValid() { + fieldValue = parentField.FieldByName(fieldPath[1]) } - resultsSlice = append(resultsSlice, result) + } else { + fieldValue = dataStruct.FieldByName(fieldName) + } + if !fieldValue.IsValid() || !fieldValue.CanSet() { + session.Engine.LogWarn("table %v's column %v is not valid or cannot set", + table.Name, key) + return nil + } + return &fieldValue +} + +func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { + + dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + if dataStruct.Kind() != reflect.Struct { + return errors.New("Expected a pointer to a struct") } - return resultsSlice, nil + table := session.Engine.autoMapType(rType(bean)) + + var scanResultContainers []interface{} + for i := 0; i < fieldsCount; i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := rows.Scan(scanResultContainers...); err != nil { + return err + } + + for ii, key := range fields { + if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { + + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + //fmt.Println("ignore ...", key, rawValue) + continue + } + + if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + structConvert.FromDB(data) + } else { + session.Engine.LogError(err) + } + continue + } + + aa := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + + fieldType := fieldValue.Type() + + //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) + + hasAssigned := false + + switch fieldType.Kind() { + + case reflect.Complex64, reflect.Complex128: + if aa.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } + case reflect.Slice, reflect.Array: + switch aa.Kind() { + case reflect.Slice, reflect.Array: + switch aa.Elem().Kind() { + case reflect.Uint8: + hasAssigned = true + fieldValue.Set(rawValue) + } + } + case reflect.String: + if aa.Kind() == reflect.String { + hasAssigned = true + fieldValue.SetString(vv.String()) + } + case reflect.Bool: + if aa.Kind() == reflect.Bool { + hasAssigned = true + fieldValue.SetBool(vv.Bool()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + hasAssigned = true + fieldValue.SetInt(vv.Int()) + } + case reflect.Float32, reflect.Float64: + switch aa.Kind() { + case reflect.Float32, reflect.Float64: + hasAssigned = true + fieldValue.SetFloat(vv.Float()) + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + switch aa.Kind() { + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + hasAssigned = true + fieldValue.SetUint(vv.Uint()) + } + //Currently only support Time type + case reflect.Struct: + if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { + if aa == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + fieldValue.Set(rawValue) + } + } + // else if session.Statement.UseCascade { // TODO + // table := session.Engine.autoMapType(fieldValue.Type()) + // if table != nil { + // x, err := strconv.ParseInt(string(data), 10, 64) + // if err != nil { + // return errors.New("arg " + key + " as int: " + er1r.Error()) + // } + // if x != 0 { + // // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // // property to be fetched lazily + // structInter := reflect.New(fieldValue.Type()) + // newsession := session.Engine.NewSession() + // defer newsession.Close() + // has, err := newsession.Id(x).Get(structInter.Interface()) + // if err != nil { + // return err + // } + // if has { + // v = structInter.Elem().Interface() + // fieldValue.Set(reflect.ValueOf(v)) + // } else { + // return errors.New("cascade obj is not exist!") + // } + // } + // } else { + // session.Engine.LogError("unsupported struct type in Scan: ", fieldValue.Type().String()) + // } + // } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + //typeStr := fieldType.String() + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case reflect.TypeOf(&c_EMPTY_STRING), reflect.TypeOf(&c_BOOL_DEFAULT), reflect.TypeOf(&c_TIME_DEFAULT), + reflect.TypeOf(&c_FLOAT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT), reflect.TypeOf(&c_INT64_DEFAULT): + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&rawValue)) + case reflect.TypeOf(&c_FLOAT32_DEFAULT): + var x float32 = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_INT_DEFAULT): + var x int = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_INT32_DEFAULT): + var x int32 = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_INT8_DEFAULT): + var x int8 = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_INT16_DEFAULT): + var x int16 = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_UINT_DEFAULT): + var x uint = uint(vv.Uint()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_UINT32_DEFAULT): + var x uint32 = uint32(vv.Uint()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_UINT8_DEFAULT): + var x uint8 = uint8(vv.Uint()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_UINT16_DEFAULT): + var x uint16 = uint16(vv.Uint()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + case reflect.TypeOf(&c_COMPLEX64_DEFAULT): + var x complex64 + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.Engine.LogError(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + case reflect.TypeOf(&c_COMPLEX128_DEFAULT): + var x complex128 + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.Engine.LogError(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + hasAssigned = true + } // switch fieldType + // default: + // session.Engine.LogError("unsupported type in Scan: ", reflect.TypeOf(v).String()) + } // switch fieldType.Kind() + + // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value + if !hasAssigned { + data, err := value2Bytes(&rawValue) + if err == nil { + session.bytes2Value(table.Columns[key], fieldValue, data) + } else { + session.Engine.LogError(err.Error()) + } + } + } + } + return nil + } func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { @@ -1334,7 +1521,6 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st } defer rows.Close() //fmt.Println(rows) - return rows2maps(rows) } @@ -1619,7 +1805,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1631,7 +1817,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x := reflect.New(fieldType) err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1642,7 +1828,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x := reflect.New(fieldType) err := json.Unmarshal(data, x.Interface()) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(x.Elem()) @@ -1656,7 +1842,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data d := string(data) v, err := strconv.ParseBool(d) if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) + return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(v)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -1684,19 +1870,19 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 10, 64) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.SetInt(x) case reflect.Float32, reflect.Float64: x, err := strconv.ParseFloat(string(data), 64) if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) + return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.SetFloat(x) case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.SetUint(x) //Currently only support Time type @@ -1713,7 +1899,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if table != nil { x, err := strconv.ParseInt(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } if x != 0 { // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch @@ -1734,7 +1920,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } } } else { - return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) + return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } case reflect.Ptr: @@ -1750,7 +1936,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data d := string(data) v, err := strconv.ParseBool(d) if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) + return fmt.Errorf("arg %v as bool: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&v)) // case "*complex64": @@ -1758,7 +1944,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x complex64 err := json.Unmarshal(data, &x) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(reflect.ValueOf(&x)) @@ -1767,7 +1953,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x complex128 err := json.Unmarshal(data, &x) if err != nil { - session.Engine.LogSQL(err) + session.Engine.LogError(err) return err } fieldValue.Set(reflect.ValueOf(&x)) @@ -1775,7 +1961,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data case reflect.TypeOf(&c_FLOAT64_DEFAULT): x, err := strconv.ParseFloat(string(data), 64) if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) + return fmt.Errorf("arg %v as float64: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*float32": @@ -1783,7 +1969,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x float32 x1, err := strconv.ParseFloat(string(data), 32) if err != nil { - return errors.New("arg " + key + " as float32: " + err.Error()) + return fmt.Errorf("arg %v as float32: %s", key, err.Error()) } x = float32(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1800,7 +1986,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*uint": @@ -1808,7 +1994,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1817,7 +2003,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint32(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1826,7 +2012,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint8(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1835,7 +2021,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } x = uint16(x1) fieldValue.Set(reflect.ValueOf(&x)) @@ -1861,7 +2047,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 10, 64) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int": @@ -1890,7 +2076,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int32": @@ -1919,7 +2105,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int32(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int8": @@ -1948,7 +2134,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int8(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) // case "*int16": @@ -1977,14 +2163,14 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x = int16(x1) } if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) + return fmt.Errorf("arg %v as int: %s", key, err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } return nil diff --git a/table.go b/table.go index aac87528..16e4b1fd 100644 --- a/table.go +++ b/table.go @@ -163,6 +163,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { if t == reflect.TypeOf(c_TIME_DEFAULT) { st = SQLType{DateTime, 0, 0} } else { + // TODO need to handle association struct st = SQLType{Text, 0, 0} } case reflect.Ptr: From 814036e2581afbb791555adf5cc706d230ac5004 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Sat, 28 Dec 2013 02:42:50 +0800 Subject: [PATCH 2/5] 1. correct use of 'sql' string clash with sql pacakage. 2. checked type conversion. --- session.go | 470 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 310 insertions(+), 160 deletions(-) diff --git a/session.go b/session.go index 9d86b1de..11164215 100644 --- a/session.go +++ b/session.go @@ -186,8 +186,8 @@ func (session *Session) Desc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " DESC" return session } @@ -197,8 +197,8 @@ func (session *Session) Asc(colNames ...string) *Session { session.Statement.OrderStr += ", " } newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" + sqlStr := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) + session.Statement.OrderStr += session.Engine.Quote(sqlStr) + " ASC" return session } @@ -392,8 +392,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b } //Execute sql -func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(sql) +func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(sqlStr) if err != nil { return nil, err } @@ -406,22 +406,22 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, return res, nil } -func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - session.Engine.LogSQL(sql) + session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(args) if session.IsAutoCommit { - return session.innerExec(sql, args...) + return session.innerExec(sqlStr, args...) } - return session.Tx.Exec(sql, args...) + return session.Tx.Exec(sqlStr, args...) } // Exec raw sql -func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { +func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { err := session.newDb() if err != nil { return nil, err @@ -431,7 +431,7 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error defer session.Close() } - return session.exec(sql, args...) + return session.exec(sqlStr, args...) } // this function create a table according a bean @@ -464,8 +464,8 @@ func (session *Session) CreateIndexes(bean interface{}) error { } sqls := session.Statement.genIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -487,8 +487,8 @@ func (session *Session) CreateUniques(bean interface{}) error { } sqls := session.Statement.genUniqueSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -497,9 +497,9 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createOneTable() error { - sql := session.Statement.genCreateTableSQL() - session.Engine.LogDebug("create table sql: [", sql, "]") - _, err := session.exec(sql) + sqlStr := session.Statement.genCreateTableSQL() + session.Engine.LogDebug("create table sql: [", sqlStr, "]") + _, err := session.exec(sqlStr) return err } @@ -536,8 +536,8 @@ func (session *Session) DropIndexes(bean interface{}) error { } sqls := session.Statement.genDelIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) + for _, sqlStr := range sqls { + _, err = session.exec(sqlStr) if err != nil { return err } @@ -567,16 +567,16 @@ func (session *Session) DropTable(bean interface{}) error { return errors.New("Unsupported type") } - sql := session.Statement.genDropSQL() - _, err = session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err = session.exec(sqlStr) return err } -func (statement *Statement) convertIdSql(sql string) string { +func (statement *Statement) convertIdSql(sqlStr string) string { if statement.RefTable != nil { col := statement.RefTable.PKColumn() if col != nil { - sqls := splitNNoCase(sql, "from", 2) + sqls := splitNNoCase(sqlStr, "from", 2) if len(sqls) != 2 { return "" } @@ -588,14 +588,14 @@ func (statement *Statement) convertIdSql(sql string) string { return "" } -func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { +func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return false, ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return false, ErrCacheFailed } @@ -667,19 +667,19 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface return false, nil } -func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { +func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" || - indexNoCase(sql, "having") != -1 || - indexNoCase(sql, "group by") != -1 { + indexNoCase(sqlStr, "having") != -1 || + indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -867,41 +867,67 @@ func (session *Session) Get(bean interface{}) (bool, error) { } session.Statement.Limit(1) - var sql string + var sqlStr string var args []interface{} session.Statement.RefTable = session.Engine.autoMap(bean) if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) + sqlStr, args = session.Statement.genGetSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { - has, err := session.cacheGet(bean, sql, args...) + has, err := session.cacheGet(bean, sqlStr, args...) if err != ErrCacheFailed { return has, err } } - resultsSlice, err := session.query(sql, args...) + var rawRows *sql.Rows + session.queryPreprocess(sqlStr, args...) + if session.IsAutoCommit { + stmt, err := session.Db.Prepare(sqlStr) + if err != nil { + return false, err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) + } else { + rawRows, err = session.Tx.Query(sqlStr, args...) + } if err != nil { return false, err } - if len(resultsSlice) < 1 { + defer rawRows.Close() + + if rawRows.Next() { + if fields, err := rawRows.Columns(); err == nil { + err = session.row2Bean(rawRows, fields, len(fields), bean) + } + return true, err + } else { return false, nil } - err = session.scanMapIntoStruct(bean, resultsSlice[0]) - if err != nil { - return true, err - } - if len(resultsSlice) == 1 { - return true, nil - } else { - return true, errors.New("More than one record") - } + // resultsSlice, err := session.query(sqlStr, args...) + // if err != nil { + // return false, err + // } + // if len(resultsSlice) < 1 { + // return false, nil + // } + + // err = session.scanMapIntoStruct(bean, resultsSlice[0]) + // if err != nil { + // return true, err + // } + // if len(resultsSlice) == 1 { + // return true, nil + // } else { + // return true, errors.New("More than one record") + // } } // Count counts the records. bean's non-empty fields @@ -917,16 +943,16 @@ func (session *Session) Count(bean interface{}) (int64, error) { defer session.Close() } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - sql, args = session.Statement.genCountSql(bean) + sqlStr, args = session.Statement.genCountSql(bean) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } - resultsSlice, err := session.query(sql, args...) + resultsSlice, err := session.query(sqlStr, args...) if err != nil { return 0, err } @@ -987,7 +1013,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.BeanArgs = args } - var sql string + var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr @@ -997,46 +1023,94 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.attachInSql() - sql = session.Statement.genSelectSql(columnStr) + sqlStr = session.Statement.genSelectSql(columnStr) args = append(session.Statement.Params, session.Statement.BeanArgs...) } else { - sql = session.Statement.RawSQL + sqlStr = session.Statement.RawSQL args = session.Statement.RawParams } if table.Cacher != nil && session.Statement.UseCache && !session.Statement.IsDistinct { - err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) + err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) if err != ErrCacheFailed { return err } session.Engine.LogWarn("Cache Find Failed") } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return err - } + if sliceValue.Kind() != reflect.Map { + var rawRows *sql.Rows - for i, results := range resultsSlice { - var newValue reflect.Value - if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) + session.queryPreprocess(sqlStr, args...) + // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) + // if err != nil { + // return err + // } + // if stmt != nil { + // defer stmt.Close() + // } + // defer rawRows.Close() + + if session.IsAutoCommit { + stmt, err := session.Db.Prepare(sqlStr) + if err != nil { + return err + } + defer stmt.Close() + rawRows, err = stmt.Query(args...) } else { - newValue = reflect.New(sliceElementType) + rawRows, err = session.Tx.Query(sqlStr, args...) } - err := session.scanMapIntoStruct(newValue.Interface(), results) if err != nil { return err } - if sliceValue.Kind() == reflect.Slice { + defer rawRows.Close() + + fields, err := rawRows.Columns() + if err != nil { + return err + } + + fieldsCount := len(fields) + + for rawRows.Next() { + var newValue reflect.Value if sliceElementType.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + newValue = reflect.New(sliceElementType.Elem()) } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + newValue = reflect.New(sliceElementType) + } + err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface()) + if err != nil { + return err + } + if sliceValue.Kind() == reflect.Slice { + if sliceElementType.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + } + } + } else { + resultsSlice, err := session.query(sqlStr, args...) + if err != nil { + return err + } + + for i, results := range resultsSlice { + var newValue reflect.Value + if sliceElementType.Kind() == reflect.Ptr { + newValue = reflect.New(sliceElementType.Elem()) + } else { + newValue = reflect.New(sliceElementType) + } + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err } - } else if sliceValue.Kind() == reflect.Map { var key int64 if table.PrimaryKey != "" { x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) @@ -1057,6 +1131,20 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } +func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { + var err error + if session.IsAutoCommit { + *rawStmt, err = session.Db.Prepare(sqlStr) + if err != nil { + return err + } + *rawRows, err = (*rawStmt).Query(args...) + } else { + *rawRows, err = session.Tx.Query(sqlStr, args...) + } + return err +} + // Test if database is ok func (session *Session) Ping() error { err := session.newDb() @@ -1080,8 +1168,8 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1094,8 +1182,8 @@ func (session *Session) isTableExist(tableName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.TableCheckSql(tableName) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1114,8 +1202,8 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo } else { idx = indexName(tableName, idxName) } - sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) - results, err := session.query(sql, args...) + sqlStr, args := session.Engine.dialect.IndexCheckSql(tableName, idx) + results, err := session.query(sqlStr, args...) return len(results) > 0, err } @@ -1149,8 +1237,8 @@ func (session *Session) addColumn(colName string) error { } //fmt.Println(session.Statement.RefTable) col := session.Statement.RefTable.Columns[colName] - sql, args := session.Statement.genAddColumnStr(col) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddColumnStr(col) + _, err = session.exec(sqlStr, args...) return err } @@ -1165,8 +1253,8 @@ func (session *Session) addIndex(tableName, idxName string) error { } //fmt.Println(idxName) cols := session.Statement.RefTable.Indexes[idxName].Cols - sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1181,8 +1269,8 @@ func (session *Session) addUnique(tableName, uqeName string) error { } //fmt.Println(uqeName, session.Statement.RefTable.Uniques) cols := session.Statement.RefTable.Indexes[uqeName].Cols - sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) - _, err = session.exec(sql, args...) + sqlStr, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) + _, err = session.exec(sqlStr, args...) return err } @@ -1200,8 +1288,8 @@ func (session *Session) dropAll() error { for _, table := range session.Engine.Tables { session.Statement.Init() session.Statement.RefTable = table - sql := session.Statement.genDropSQL() - _, err := session.exec(sql) + sqlStr := session.Statement.genDropSQL() + _, err := session.exec(sqlStr) if err != nil { return err } @@ -1306,7 +1394,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in continue } - aa := reflect.TypeOf(rawValue.Interface()) + rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) fieldType := fieldValue.Type() @@ -1318,7 +1406,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in switch fieldType.Kind() { case reflect.Complex64, reflect.Complex128: - if aa.Kind() == reflect.String { + if rawValueType.Kind() == reflect.String { hasAssigned = true x := reflect.New(fieldType) err := json.Unmarshal([]byte(vv.String()), x.Interface()) @@ -1329,38 +1417,40 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in fieldValue.Set(x.Elem()) } case reflect.Slice, reflect.Array: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Slice, reflect.Array: - switch aa.Elem().Kind() { + switch rawValueType.Elem().Kind() { case reflect.Uint8: - hasAssigned = true - fieldValue.Set(rawValue) + if fieldType.Elem().Kind() == reflect.Uint8 { + hasAssigned = true + fieldValue.Set(vv) + } } } case reflect.String: - if aa.Kind() == reflect.String { + if rawValueType.Kind() == reflect.String { hasAssigned = true fieldValue.SetString(vv.String()) } case reflect.Bool: - if aa.Kind() == reflect.Bool { + if rawValueType.Kind() == reflect.Bool { hasAssigned = true fieldValue.SetBool(vv.Bool()) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: hasAssigned = true fieldValue.SetInt(vv.Int()) } case reflect.Float32, reflect.Float64: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Float32, reflect.Float64: hasAssigned = true fieldValue.SetFloat(vv.Float()) } case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - switch aa.Kind() { + switch rawValueType.Kind() { case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: hasAssigned = true fieldValue.SetUint(vv.Uint()) @@ -1368,7 +1458,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in //Currently only support Time type case reflect.Struct: if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { - if aa == reflect.TypeOf(c_TIME_DEFAULT) { + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { hasAssigned = true fieldValue.Set(rawValue) } @@ -1407,46 +1497,95 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in //typeStr := fieldType.String() switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly - case reflect.TypeOf(&c_EMPTY_STRING), reflect.TypeOf(&c_BOOL_DEFAULT), reflect.TypeOf(&c_TIME_DEFAULT), - reflect.TypeOf(&c_FLOAT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT), reflect.TypeOf(&c_INT64_DEFAULT): - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&rawValue)) + case reflect.TypeOf(&c_EMPTY_STRING): + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_BOOL_DEFAULT): + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_TIME_DEFAULT): + if rawValueType == reflect.TypeOf(c_TIME_DEFAULT) { + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&rawValue)) + } + case reflect.TypeOf(&c_FLOAT64_DEFAULT): + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_UINT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + var x uint64 = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case reflect.TypeOf(&c_INT64_DEFAULT): + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_FLOAT32_DEFAULT): - var x float32 = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Float64 { + var x float32 = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT_DEFAULT): - var x int = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT32_DEFAULT): - var x int32 = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int32 = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT8_DEFAULT): - var x int8 = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int8 = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_INT16_DEFAULT): - var x int16 = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x int16 = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT_DEFAULT): - var x uint = uint(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT32_DEFAULT): - var x uint32 = uint32(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint32 = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT8_DEFAULT): - var x uint8 = uint8(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint8 = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_UINT16_DEFAULT): - var x uint16 = uint16(vv.Uint()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) + if rawValueType.Kind() == reflect.Int64 { + var x uint16 = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } case reflect.TypeOf(&c_COMPLEX64_DEFAULT): var x complex64 err := json.Unmarshal([]byte(vv.String()), &x) @@ -1485,22 +1624,33 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in } -func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) queryPreprocess(sqlStr string, paramStr ...interface{}) { for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - session.Engine.LogSQL(sql) + session.Engine.LogSQL(sqlStr) + session.Engine.LogSQL(paramStr) +} + +func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + // !nashtsai! TODO calling session.queryPreprocess with cause error + // session.queryPreprocess(sqlStr, paramStr...) + for _, filter := range session.Engine.Filters { + sqlStr = filter.Do(sqlStr, session) + } + + session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(paramStr) if session.IsAutoCommit { - return query(session.Db, sql, paramStr...) + return query(session.Db, sqlStr, paramStr...) } - return txQuery(session.Tx, sql, paramStr...) + return txQuery(session.Tx, sqlStr, paramStr...) } -func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - rows, err := tx.Query(sql, params...) +func txQuery(tx *sql.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + rows, err := tx.Query(sqlStr, params...) if err != nil { return nil, err } @@ -1509,8 +1659,8 @@ func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[ return rows2maps(rows) } -func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - s, err := db.Prepare(sql) +func query(db *sql.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + s, err := db.Prepare(sqlStr) if err != nil { return nil, err } @@ -1525,7 +1675,7 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st } // Exec a raw sql and return records as []map[string][]byte -func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { +func (session *Session) Query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { err = session.newDb() if err != nil { return nil, err @@ -1535,7 +1685,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice defer session.Close() } - return session.query(sql, paramStr...) + return session.query(sqlStr, paramStr...) } // insert one or more beans @@ -2310,7 +2460,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces = colPlaces[0 : len(colPlaces)-2] - sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", + sqlStr := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", session.Engine.QuoteStr(), session.Statement.TableName(), session.Engine.QuoteStr(), @@ -2351,7 +2501,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } else { @@ -2395,8 +2545,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } else { - sql = sql + " RETURNING (id)" - res, err := session.query(sql, args...) + sqlStr = sqlStr + " RETURNING (id)" + res, err := session.query(sqlStr, args...) if err != nil { return 0, err } else { @@ -2458,11 +2608,11 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (statement *Statement) convertUpdateSql(sql string) (string, string) { +func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) { if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { return "", "" } - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", @@ -2505,12 +2655,12 @@ func (session *Session) cacheInsert(tables ...string) error { return nil } -func (session *Session) cacheUpdate(sql string, args ...interface{}) error { +func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return ErrCacheFailed } - oldhead, newsql := session.Statement.convertUpdateSql(sql) + oldhead, newsql := session.Statement.convertUpdateSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2521,7 +2671,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { var nStart int if len(args) > 0 { - if strings.Index(sql, "?") > -1 { + if strings.Index(sqlStr, "?") > -1 { nStart = strings.Count(oldhead, "?") } else { // only for pq, TODO: if any other databse? @@ -2562,7 +2712,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := splitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) == 0 || len(sqls) > 2 { return ErrCacheFailed } @@ -2701,7 +2851,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - var sql, inSql string + var sqlStr, inSql string var inArgs []interface{} if table.Version != "" && session.Statement.checkVersion { if condition != "" { @@ -2719,7 +2869,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v, %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", @@ -2739,7 +2889,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - sql = fmt.Sprintf("UPDATE %v SET %v %v", + sqlStr = fmt.Sprintf("UPDATE %v SET %v %v", session.Engine.Quote(session.Statement.TableName()), strings.Join(colNames, ", "), condition) @@ -2749,13 +2899,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, inArgs...) args = append(args, condiArgs...) - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } if table.Cacher != nil && session.Statement.UseCache { - //session.cacheUpdate(sql, args...) + //session.cacheUpdate(sqlStr, args...) table.Cacher.ClearIds(session.Statement.TableName()) table.Cacher.ClearBeans(session.Statement.TableName()) } @@ -2792,16 +2942,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return res.RowsAffected() } -func (session *Session) cacheDelete(sql string, args ...interface{}) error { +func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { return ErrCacheFailed } for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) + sqlStr = filter.Do(sqlStr, session) } - newsql := session.Statement.convertIdSql(sql) + newsql := session.Statement.convertIdSql(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -2893,16 +3043,16 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return 0, ErrNeedDeletedCond } - sql := fmt.Sprintf("DELETE FROM %v WHERE %v", + sqlStr := fmt.Sprintf("DELETE FROM %v WHERE %v", session.Engine.Quote(session.Statement.TableName()), condition) args = append(session.Statement.Params, args...) if table.Cacher != nil && session.Statement.UseCache { - session.cacheDelete(sql, args...) + session.cacheDelete(sqlStr, args...) } - res, err := session.exec(sql, args...) + res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } From 8f0aba838f56fe7ead42aa2c694abc435073aca3 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Sun, 29 Dec 2013 17:37:43 +0800 Subject: [PATCH 3/5] use session.queryPreprocess implementations instead of calling session.queryPreprocess to void sql error --- session.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/session.go b/session.go index 11164215..36a7e193 100644 --- a/session.go +++ b/session.go @@ -886,7 +886,16 @@ func (session *Session) Get(bean interface{}) (bool, error) { } var rawRows *sql.Rows - session.queryPreprocess(sqlStr, args...) + // !nashtsai! TODO calling session.queryPreprocess with cause error + // session.queryPreprocess(sqlStr, args...) + + for _, filter := range session.Engine.Filters { + sqlStr = filter.Do(sqlStr, session) + } + + session.Engine.LogSQL(sqlStr) + session.Engine.LogSQL(args) + if session.IsAutoCommit { stmt, err := session.Db.Prepare(sqlStr) if err != nil { @@ -1043,7 +1052,15 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceValue.Kind() != reflect.Map { var rawRows *sql.Rows - session.queryPreprocess(sqlStr, args...) + // !nashtsai! TODO calling session.queryPreprocess with cause error + // session.queryPreprocess(sqlStr, args...) + for _, filter := range session.Engine.Filters { + sqlStr = filter.Do(sqlStr, session) + } + + session.Engine.LogSQL(sqlStr) + session.Engine.LogSQL(args) + // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) // if err != nil { // return err From 5fa8a7271db0aa82a80fabbc7fd344bd9d15bfc0 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Mon, 30 Dec 2013 02:32:26 +0800 Subject: [PATCH 4/5] fixed session.queryPreprocess usage --- session.go | 36 ++++++------------------------------ 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/session.go b/session.go index 36a7e193..d22f9ed8 100644 --- a/session.go +++ b/session.go @@ -886,16 +886,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { } var rawRows *sql.Rows - // !nashtsai! TODO calling session.queryPreprocess with cause error - // session.queryPreprocess(sqlStr, args...) - - for _, filter := range session.Engine.Filters { - sqlStr = filter.Do(sqlStr, session) - } - - session.Engine.LogSQL(sqlStr) - session.Engine.LogSQL(args) - + session.queryPreprocess(&sqlStr, args...) if session.IsAutoCommit { stmt, err := session.Db.Prepare(sqlStr) if err != nil { @@ -1052,15 +1043,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if sliceValue.Kind() != reflect.Map { var rawRows *sql.Rows - // !nashtsai! TODO calling session.queryPreprocess with cause error - // session.queryPreprocess(sqlStr, args...) - for _, filter := range session.Engine.Filters { - sqlStr = filter.Do(sqlStr, session) - } - - session.Engine.LogSQL(sqlStr) - session.Engine.LogSQL(args) - + session.queryPreprocess(&sqlStr, args...) // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) // if err != nil { // return err @@ -1641,24 +1624,17 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in } -func (session *Session) queryPreprocess(sqlStr string, paramStr ...interface{}) { +func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { for _, filter := range session.Engine.Filters { - sqlStr = filter.Do(sqlStr, session) + *sqlStr = filter.Do(*sqlStr, session) } - session.Engine.LogSQL(sqlStr) + session.Engine.LogSQL(*sqlStr) session.Engine.LogSQL(paramStr) } func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - // !nashtsai! TODO calling session.queryPreprocess with cause error - // session.queryPreprocess(sqlStr, paramStr...) - for _, filter := range session.Engine.Filters { - sqlStr = filter.Do(sqlStr, session) - } - - session.Engine.LogSQL(sqlStr) - session.Engine.LogSQL(paramStr) + session.queryPreprocess(&sqlStr, paramStr...) if session.IsAutoCommit { return query(session.Db, sqlStr, paramStr...) From e7379dc7d98490fa245244a3dd7f24c427c81f77 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Mon, 30 Dec 2013 17:40:59 +0800 Subject: [PATCH 5/5] fixed cache usage error when using session.Find() --- session.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/session.go b/session.go index d22f9ed8..34f51648 100644 --- a/session.go +++ b/session.go @@ -1037,11 +1037,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if err != ErrCacheFailed { return err } + err = nil // !nashtsai! reset err to nil for ErrCacheFailed session.Engine.LogWarn("Cache Find Failed") } if sliceValue.Kind() != reflect.Map { var rawRows *sql.Rows + var stmt *sql.Stmt session.queryPreprocess(&sqlStr, args...) // err = session.queryRows(&stmt, &rawRows, sqlStr, args...) @@ -1054,7 +1056,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) // defer rawRows.Close() if session.IsAutoCommit { - stmt, err := session.Db.Prepare(sqlStr) + stmt, err = session.Db.Prepare(sqlStr) if err != nil { return err }