diff --git a/engine.go b/engine.go index 26d84d25..6e5f14fe 100644 --- a/engine.go +++ b/engine.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/go-xorm/builder" "github.com/go-xorm/core" ) @@ -1562,3 +1563,11 @@ func (engine *Engine) Unscoped() *Session { session.IsAutoClose = true return session.Unscoped() } + +// CondDeleted returns the conditions whether a record is soft deleted. +func (engine *Engine) CondDeleted(colName string) builder.Cond { + if engine.dialect.DBType() == core.MSSQL { + return builder.IsNull{colName} + } + return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1}) +} diff --git a/engine_cond.go b/engine_cond.go new file mode 100644 index 00000000..6c8e3879 --- /dev/null +++ b/engine_cond.go @@ -0,0 +1,230 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/go-xorm/builder" + "github.com/go-xorm/core" +) + +func (engine *Engine) buildConds(table *core.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { + var conds []builder.Cond + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + + if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { + continue + } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = engine.Quote(nm) + "." + engine.Quote(col.Name) + } else { + colName = engine.Quote(col.Name) + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + engine.logger.Error(err) + continue + } + + if col.IsDeleted && !unscoped { // tag "deleted" is enabled + conds = append(conds, engine.CondDeleted(colName)) + } + + fieldValue := *fieldValuePtr + if fieldValue.Interface() == nil { + continue + } + + fieldType := reflect.TypeOf(fieldValue.Interface()) + requiredField := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + conds = append(conds, builder.Eq{colName: nil}) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = engine.formatColTime(col, t) + } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { + continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil { + continue + } + } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = bytes + } + } else { + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !isZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } else { + val = fieldValue.Interface() + } + } + } + case reflect.Array: + continue + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + conds = append(conds, builder.Eq{colName: val}) + } + + return builder.And(conds...), nil +} diff --git a/session.go b/session.go index afcab3c9..de4a1a04 100644 --- a/session.go +++ b/session.go @@ -567,7 +567,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } } else { - panic(fmt.Sprintf("rawValueType is %v, value is %v", rawValueType, vv.Interface())) + return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) } } } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { @@ -607,7 +607,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i hasAssigned = true if len(table.PrimaryKeys) != 1 { - panic("unsupported non or composited primary key cascade") + return nil, errors.New("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) pk[0], err = asKind(vv, rawValueType) diff --git a/session_convert.go b/session_convert.go index df44ace7..931d1dc0 100644 --- a/session_convert.go +++ b/session_convert.go @@ -28,8 +28,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti parseLoc = col.TimeZone } - if sdata == "0000-00-00 00:00:00" || - sdata == "0001-01-01 00:00:00" { + if sdata == zeroTime0 || sdata == zeroTime1 { } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column // time stamp sd, err := strconv.ParseInt(sdata, 10, 64) @@ -213,8 +212,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, // TODO: current only support 1 primary key if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + return errors.New("unsupported composited primary key cascade") } + var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) @@ -496,8 +496,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + return errors.New("unsupported composited primary key cascade") } + var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) diff --git a/session_find.go b/session_find.go index 9b8b31ef..be64878a 100644 --- a/session_find.go +++ b/session_find.go @@ -66,7 +66,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var err error autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) if err != nil { - panic(err) + return err } } else { // !oinume! Add "