Merge pull request #270 from winxxp/master

Add support for sql.NullString, sql.NullInt64, and other sql.Null* types
This commit is contained in:
Lunny Xiao 2015-07-27 09:44:30 +08:00
commit a4765bce78
2 changed files with 135 additions and 106 deletions

View File

@ -6,6 +6,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -637,7 +638,7 @@ func (session *Session) canCache() bool {
if session.Statement.RefTable == nil || if session.Statement.RefTable == nil ||
session.Statement.JoinStr != "" || session.Statement.JoinStr != "" ||
session.Statement.RawSQL != "" || session.Statement.RawSQL != "" ||
session.Tx != nil || session.Tx != nil ||
len(session.Statement.selectStr) > 0 { len(session.Statement.selectStr) > 0 {
return false return false
} }
@ -744,7 +745,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) {
if !session.canCache() || if !session.canCache() ||
indexNoCase(sqlStr, "having") != -1 || indexNoCase(sqlStr, "having") != -1 ||
indexNoCase(sqlStr, "group by") != -1 { indexNoCase(sqlStr, "group by") != -1 {
return ErrCacheFailed return ErrCacheFailed
@ -1187,7 +1188,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var addedTableName = (len(session.Statement.JoinStr) > 0) var addedTableName = (len(session.Statement.JoinStr) > 0)
colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true,
false, true, session.Statement.allUseBool, session.Statement.useAllCols, false, true, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.unscoped, session.Statement.mustColumnMap, session.Statement.unscoped, session.Statement.mustColumnMap,
session.Statement.TableName(), addedTableName) session.Statement.TableName(), addedTableName)
session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.ConditionStr = strings.Join(colNames, " AND ")
session.Statement.BeanArgs = args session.Statement.BeanArgs = args
@ -1314,7 +1315,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
table := session.Engine.autoMapType(dataStruct) table := session.Engine.autoMapType(dataStruct)
return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc) return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc)
} else { } else {
resultsSlice, err := session.query(sqlStr, args...) resultsSlice, err := session.query(sqlStr, args...)
@ -1755,6 +1755,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
vv = reflect.ValueOf(t) vv = reflect.ValueOf(t)
fieldValue.Set(vv) fieldValue.Set(vv)
} }
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
fmt.Println("sql.Sanner error:", err.Error())
session.Engine.LogError("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
table := session.Engine.autoMapType(*fieldValue) table := session.Engine.autoMapType(*fieldValue)
if table != nil { if table != nil {
@ -1762,6 +1770,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
panic("unsupported composited primary key cascade") panic("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(core.PK, len(table.PrimaryKeys))
switch rawValueType.Kind() { switch rawValueType.Kind() {
case reflect.Int64: case reflect.Int64:
pk[0] = vv.Int() pk[0] = vv.Int()
@ -2416,108 +2425,115 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
fieldValue.SetUint(x) fieldValue.SetUint(x)
//Currently only support Time type //Currently only support Time type
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { // !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
x, err := session.byte2Time(col, data) if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
if err != nil { if err := nulVal.Scan(data); err != nil {
return err return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error())
} }
v = x } else {
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) if fieldType.ConvertibleTo(core.TimeType) {
} else if session.Statement.UseCascade { x, err := session.byte2Time(col, data)
table := session.Engine.autoMapType(*fieldValue) if err != nil {
if table != nil { return err
if len(table.PrimaryKeys) > 1 {
panic("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) v = x
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
switch rawValueType.Kind() { } else if session.Statement.UseCascade {
case reflect.Int64: table := session.Engine.autoMapType(*fieldValue)
x, err := strconv.ParseInt(string(data), 10, 64) if table != nil {
if err != nil { if len(table.PrimaryKeys) > 1 {
return fmt.Errorf("arg %v as int: %s", key, err.Error()) panic("unsupported composited primary key cascade")
} }
pk[0] = x var pk = make(core.PK, len(table.PrimaryKeys))
case reflect.Int: rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
x, err := strconv.ParseInt(string(data), 10, 64) switch rawValueType.Kind() {
if err != nil { case reflect.Int64:
return fmt.Errorf("arg %v as int: %s", key, err.Error()) x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = x
case reflect.Int:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int(x)
case reflect.Int32:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int32(x)
case reflect.Int16:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int16(x)
case reflect.Int8:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int8(x)
case reflect.Uint64:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = x
case reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint(x)
case reflect.Uint32:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint32(x)
case reflect.Uint16:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint16(x)
case reflect.Uint8:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint8(x)
case reflect.String:
pk[0] = string(data)
default:
panic("unsupported primary key type cascade")
} }
pk[0] = int(x)
case reflect.Int32:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int32(x)
case reflect.Int16:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int16(x)
case reflect.Int8:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int8(x)
case reflect.Uint64:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = x
case reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint(x)
case reflect.Uint32:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint32(x)
case reflect.Uint16:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint16(x)
case reflect.Uint8:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint8(x)
case reflect.String:
pk[0] = string(data)
default:
panic("unsupported primary key type cascade")
}
if !isPKZero(pk) { if !isPKZero(pk) {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily // property to be fetched lazily
structInter := reflect.New(fieldValue.Type()) structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession() newsession := session.Engine.NewSession()
defer newsession.Close() defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return err return err
} }
if has { if has {
v = structInter.Elem().Interface() v = structInter.Elem().Interface()
fieldValue.Set(reflect.ValueOf(v)) fieldValue.Set(reflect.ValueOf(v))
} else { } else {
return errors.New("cascade obj is not exist!") return errors.New("cascade obj is not exist!")
}
} }
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
} }
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
} }
} }
case reflect.Ptr: case reflect.Ptr:
@ -2931,6 +2947,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
tf := session.Engine.FormatTime(col.SQLType.Name, t) tf := session.Engine.FormatTime(col.SQLType.Name, t)
return tf, nil return tf, nil
} }
if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
if len(fieldTable.PrimaryKeys) == 1 { if len(fieldTable.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
@ -2939,6 +2956,11 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return 0, fmt.Errorf("no primary key for col %v", col.Name) return 0, fmt.Errorf("no primary key for col %v", col.Name)
} }
} else { } else {
// !<winxxp>! 增加支持driver.Valuer接口的结构如sql.NullString
if v, ok := fieldValue.Interface().(driver.Valuer); ok {
return v.Value()
}
return 0, fmt.Errorf("Unsupported type %v", fieldValue.Type()) return 0, fmt.Errorf("Unsupported type %v", fieldValue.Type())
} }
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
@ -2998,12 +3020,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
processor.BeforeInsert() processor.BeforeInsert()
} }
// -- // --
colNames, args, err := genCols(table, session, bean, false, false) colNames, args, err := genCols(table, session, bean, false, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
// insert expr columns, override if exists // insert expr columns, override if exists
exprColumns := session.Statement.getExpr() exprColumns := session.Statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns)) exprColVals := make([]string, 0, len(exprColumns))
@ -3414,7 +3434,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if session.Statement.ColumnStr == "" { if session.Statement.ColumnStr == "" {
colNames, args = buildUpdates(session.Engine, table, bean, false, false, colNames, args = buildUpdates(session.Engine, table, bean, false, false,
false, false, session.Statement.allUseBool, session.Statement.useAllCols, false, false, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.mustColumnMap, session.Statement.nullableMap, session.Statement.mustColumnMap, session.Statement.nullableMap,
session.Statement.columnMap, true) session.Statement.columnMap, true)
} else { } else {
colNames, args, err = genCols(table, session, bean, true, true) colNames, args, err = genCols(table, session, bean, true, true)
@ -3696,7 +3716,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := buildConditions(session.Engine, table, bean, true, true, colNames, args := buildConditions(session.Engine, table, bean, true, true,
false, true, session.Statement.allUseBool, session.Statement.useAllCols, false, true, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.unscoped, session.Statement.mustColumnMap, session.Statement.unscoped, session.Statement.mustColumnMap,
session.Statement.TableName(), false) session.Statement.TableName(), false)
var condition = "" var condition = ""

View File

@ -5,6 +5,7 @@
package xorm package xorm
import ( import (
"database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -49,7 +50,7 @@ type Statement struct {
GroupByStr string GroupByStr string
HavingStr string HavingStr string
ColumnStr string ColumnStr string
selectStr string selectStr string
columnMap map[string]bool columnMap map[string]bool
useAllCols bool useAllCols bool
OmitStr string OmitStr string
@ -219,6 +220,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
requiredField := useAllCols requiredField := useAllCols
includeNil := useAllCols includeNil := useAllCols
lColName := strings.ToLower(col.Name) lColName := strings.ToLower(col.Name)
if b, ok := mustColumnMap[lColName]; ok { if b, ok := mustColumnMap[lColName]; ok {
if b { if b {
requiredField = true requiredField = true
@ -320,6 +322,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
continue continue
} }
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = nulType.Value()
} else { } else {
engine.autoMapType(fieldValue) engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {
@ -416,7 +420,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
var colName string var colName string
if addedTableName { if addedTableName {
colName = engine.Quote(tableName)+"."+engine.Quote(col.Name) colName = engine.Quote(tableName) + "." + engine.Quote(col.Name)
} else { } else {
colName = engine.Quote(col.Name) colName = engine.Quote(col.Name)
} }
@ -428,7 +432,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
} }
if col.IsDeleted && !unscoped { // tag "deleted" is enabled if col.IsDeleted && !unscoped { // tag "deleted" is enabled
colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')",
colName, colName)) colName, colName))
} }
@ -509,6 +513,11 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
continue continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else { } else {
engine.autoMapType(fieldValue) engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {