From d716685a9ebe976af67b740e5f8b01156a41979f Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 20 Jul 2015 09:25:07 +0800 Subject: [PATCH 1/4] add support sql.NullString sql.NullInt64 ... --- session.go | 226 +++++++++++++++++++++++++++++------------------------ 1 file changed, 123 insertions(+), 103 deletions(-) diff --git a/session.go b/session.go index 79dbe3ae..a1be49e5 100644 --- a/session.go +++ b/session.go @@ -6,6 +6,7 @@ package xorm import ( "database/sql" + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -637,7 +638,7 @@ func (session *Session) canCache() bool { if session.Statement.RefTable == nil || session.Statement.JoinStr != "" || session.Statement.RawSQL != "" || - session.Tx != nil || + session.Tx != nil || len(session.Statement.selectStr) > 0 { return false } @@ -744,7 +745,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { - if !session.canCache() || + if !session.canCache() || indexNoCase(sqlStr, "having") != -1 || indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed @@ -1187,7 +1188,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var addedTableName = (len(session.Statement.JoinStr) > 0) colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap, + session.Statement.unscoped, session.Statement.mustColumnMap, session.Statement.TableName(), addedTableName) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args @@ -1314,7 +1315,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } table := session.Engine.autoMapType(dataStruct) - return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc) } else { resultsSlice, err := session.query(sqlStr, args...) @@ -1755,6 +1755,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount vv = reflect.ValueOf(t) fieldValue.Set(vv) } + } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + // !! 增加支持sql.Scanner接口的结构,如sql.NullString + hasAssigned = true + if err := nulVal.Scan(vv.Interface()); err != nil { + fmt.Println("sql.Sanner error:", err.Error()) + session.Engine.LogError("sql.Sanner error:", err.Error()) + hasAssigned = false + } } else if session.Statement.UseCascade { table := session.Engine.autoMapType(*fieldValue) if table != nil { @@ -1762,6 +1770,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount panic("unsupported composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) + switch rawValueType.Kind() { case reflect.Int64: pk[0] = vv.Int() @@ -2416,108 +2425,115 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, fieldValue.SetUint(x) //Currently only support Time type case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - x, err := session.byte2Time(col, data) - if err != nil { - return err + // !! 增加支持sql.Scanner接口的结构,如sql.NullString + if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + if err := nulVal.Scan(data); err != nil { + return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) } - v = x - fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(*fieldValue) - if table != nil { - if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + } else { + if fieldType.ConvertibleTo(core.TimeType) { + x, err := session.byte2Time(col, data) + if err != nil { + return err } - var pk = make(core.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - switch rawValueType.Kind() { - case reflect.Int64: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) + v = x + fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) + } else if session.Statement.UseCascade { + table := session.Engine.autoMapType(*fieldValue) + if table != nil { + if len(table.PrimaryKeys) > 1 { + panic("unsupported composited primary key cascade") } - pk[0] = x - case reflect.Int: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) + var pk = make(core.PK, len(table.PrimaryKeys)) + rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + switch rawValueType.Kind() { + case reflect.Int64: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = x + case reflect.Int: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int(x) + case reflect.Int32: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int32(x) + case reflect.Int16: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int16(x) + case reflect.Int8: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int8(x) + case reflect.Uint64: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = x + case reflect.Uint: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint(x) + case reflect.Uint32: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint32(x) + case reflect.Uint16: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint16(x) + case reflect.Uint8: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint8(x) + case reflect.String: + pk[0] = string(data) + default: + panic("unsupported primary key type cascade") } - pk[0] = int(x) - case reflect.Int32: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = int32(x) - case reflect.Int16: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = int16(x) - case reflect.Int8: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = int8(x) - case reflect.Uint64: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = x - case reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint(x) - case reflect.Uint32: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint32(x) - case reflect.Uint16: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint16(x) - case reflect.Uint8: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint8(x) - case reflect.String: - pk[0] = string(data) - default: - panic("unsupported primary key type cascade") - } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist!") + if !isPKZero(pk) { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v = structInter.Elem().Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } } + } else { + return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } case reflect.Ptr: @@ -2931,6 +2947,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val tf := session.Engine.FormatTime(col.SQLType.Name, t) return tf, nil } + if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { if len(fieldTable.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) @@ -2939,6 +2956,11 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return 0, fmt.Errorf("no primary key for col %v", col.Name) } } else { + // !! 增加支持driver.Valuer接口的结构,如sql.NullString + if v, ok := fieldValue.Interface().(driver.Valuer); ok { + return v.Value() + } + return 0, fmt.Errorf("Unsupported type %v", fieldValue.Type()) } case reflect.Complex64, reflect.Complex128: @@ -2998,12 +3020,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } // -- - colNames, args, err := genCols(table, session, bean, false, false) if err != nil { return 0, err } - // insert expr columns, override if exists exprColumns := session.Statement.getExpr() exprColVals := make([]string, 0, len(exprColumns)) @@ -3414,7 +3434,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildUpdates(session.Engine, table, bean, false, false, false, false, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, session.Statement.nullableMap, + session.Statement.mustColumnMap, session.Statement.nullableMap, session.Statement.columnMap, true) } else { colNames, args, err = genCols(table, session, bean, true, true) @@ -3696,7 +3716,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap, + session.Statement.unscoped, session.Statement.mustColumnMap, session.Statement.TableName(), false) var condition = "" From 916367d81e4383436220bc89496a20556298bc9e Mon Sep 17 00:00:00 2001 From: haolei Date: Mon, 20 Jul 2015 11:23:29 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=94=AF=E6=8C=81sql.Nul?= =?UTF-8?q?lString...Iterate=E6=96=B9=E6=B3=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- statement.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/statement.go b/statement.go index 0a8de1ee..d5f6d072 100644 --- a/statement.go +++ b/statement.go @@ -5,6 +5,7 @@ package xorm import ( + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -49,7 +50,7 @@ type Statement struct { GroupByStr string HavingStr string ColumnStr string - selectStr string + selectStr string columnMap map[string]bool useAllCols bool OmitStr string @@ -416,7 +417,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, var colName string if addedTableName { - colName = engine.Quote(tableName)+"."+engine.Quote(col.Name) + colName = engine.Quote(tableName) + "." + engine.Quote(col.Name) } else { colName = engine.Quote(col.Name) } @@ -428,7 +429,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, } if col.IsDeleted && !unscoped { // tag "deleted" is enabled - colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", + colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", colName, colName)) } @@ -509,6 +510,11 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, val = engine.FormatTime(col.SQLType.Name, t) } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil { + continue + } } else { engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { From 3189b41ec3dc91e263199ddca7a6f0e63da90f5d Mon Sep 17 00:00:00 2001 From: haolei Date: Mon, 20 Jul 2015 14:13:13 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E6=94=AF=E6=8C=81sql.Nul?= =?UTF-8?q?lString=E5=AD=97=E6=AE=B5=E6=97=B6=EF=BC=8C=E4=BD=BF=E7=94=A8It?= =?UTF-8?q?erater=E6=96=B9=E6=B3=95=E6=97=B6=E5=B4=A9=E6=BA=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- statement.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/statement.go b/statement.go index d5f6d072..56c0471f 100644 --- a/statement.go +++ b/statement.go @@ -321,6 +321,10 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, continue } val = engine.FormatTime(col.SQLType.Name, t) + } else if nulVal, ok := fieldValue.Interface().(driver.Valuer); ok { + if val, _ = nulVal.Value(); val == nil { + continue + } } else { engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { From 88a0888d6d897ffd0e033275839b6f5b10624198 Mon Sep 17 00:00:00 2001 From: haolei Date: Sun, 26 Jul 2015 00:01:06 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E6=94=AF=E6=8C=81sql.Nul?= =?UTF-8?q?lString...=E6=97=B6AllCols()=E6=97=B6=E6=B2=A1=E6=9C=89?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E7=A9=BA=E5=AD=97=E6=AE=B5=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- statement.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/statement.go b/statement.go index 56c0471f..6919b6c5 100644 --- a/statement.go +++ b/statement.go @@ -220,6 +220,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, requiredField := useAllCols includeNil := useAllCols lColName := strings.ToLower(col.Name) + if b, ok := mustColumnMap[lColName]; ok { if b { requiredField = true @@ -321,10 +322,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, continue } val = engine.FormatTime(col.SQLType.Name, t) - } else if nulVal, ok := fieldValue.Interface().(driver.Valuer); ok { - if val, _ = nulVal.Value(); val == nil { - continue - } + } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = nulType.Value() } else { engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok {