diff --git a/engine.go b/engine.go index 26d84d25..6e5f14fe 100644 --- a/engine.go +++ b/engine.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "github.com/go-xorm/builder" "github.com/go-xorm/core" ) @@ -1562,3 +1563,11 @@ func (engine *Engine) Unscoped() *Session { session.IsAutoClose = true return session.Unscoped() } + +// CondDeleted returns the conditions whether a record is soft deleted. +func (engine *Engine) CondDeleted(colName string) builder.Cond { + if engine.dialect.DBType() == core.MSSQL { + return builder.IsNull{colName} + } + return builder.IsNull{colName}.Or(builder.Eq{colName: zeroTime1}) +} diff --git a/engine_cond.go b/engine_cond.go new file mode 100644 index 00000000..6c8e3879 --- /dev/null +++ b/engine_cond.go @@ -0,0 +1,230 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/go-xorm/builder" + "github.com/go-xorm/core" +) + +func (engine *Engine) buildConds(table *core.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { + var conds []builder.Cond + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + + if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { + continue + } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = engine.Quote(nm) + "." + engine.Quote(col.Name) + } else { + colName = engine.Quote(col.Name) + } + + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + engine.logger.Error(err) + continue + } + + if col.IsDeleted && !unscoped { // tag "deleted" is enabled + conds = append(conds, engine.CondDeleted(colName)) + } + + fieldValue := *fieldValuePtr + if fieldValue.Interface() == nil { + continue + } + + fieldType := reflect.TypeOf(fieldValue.Interface()) + requiredField := useAllCols + + if b, ok := getFlagForColumn(mustColumnMap, col); ok { + if b { + requiredField = true + } else { + continue + } + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + conds = append(conds, builder.Eq{colName: nil}) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() + case reflect.Struct: + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = engine.formatColTime(col, t) + } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { + continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil { + continue + } + } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = bytes + } + } else { + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !isZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } else { + val = fieldValue.Interface() + } + } + } + case reflect.Array: + continue + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.logger.Error(err) + continue + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + conds = append(conds, builder.Eq{colName: val}) + } + + return builder.And(conds...), nil +} diff --git a/session.go b/session.go index afcab3c9..de4a1a04 100644 --- a/session.go +++ b/session.go @@ -567,7 +567,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) } } else { - panic(fmt.Sprintf("rawValueType is %v, value is %v", rawValueType, vv.Interface())) + return nil, fmt.Errorf("rawValueType is %v, value is %v", rawValueType, vv.Interface()) } } } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { @@ -607,7 +607,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i hasAssigned = true if len(table.PrimaryKeys) != 1 { - panic("unsupported non or composited primary key cascade") + return nil, errors.New("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) pk[0], err = asKind(vv, rawValueType) diff --git a/session_convert.go b/session_convert.go index df44ace7..931d1dc0 100644 --- a/session_convert.go +++ b/session_convert.go @@ -28,8 +28,7 @@ func (session *Session) str2Time(col *core.Column, data string) (outTime time.Ti parseLoc = col.TimeZone } - if sdata == "0000-00-00 00:00:00" || - sdata == "0001-01-01 00:00:00" { + if sdata == zeroTime0 || sdata == zeroTime1 { } else if !strings.ContainsAny(sdata, "- :") { // !nashtsai! has only found that mymysql driver is using this for time type column // time stamp sd, err := strconv.ParseInt(sdata, 10, 64) @@ -213,8 +212,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, // TODO: current only support 1 primary key if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + return errors.New("unsupported composited primary key cascade") } + var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) @@ -496,8 +496,9 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + return errors.New("unsupported composited primary key cascade") } + var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) diff --git a/session_find.go b/session_find.go index 9b8b31ef..be64878a 100644 --- a/session_find.go +++ b/session_find.go @@ -66,7 +66,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var err error autoCond, err = session.Statement.buildConds(table, condiBean[0], true, true, false, true, addedTableName) if err != nil { - panic(err) + return err } } else { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. @@ -80,11 +80,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } colName = session.Engine.Quote(nm) + "." + colName } - if session.Engine.dialect.DBType() == core.MSSQL { - autoCond = builder.IsNull{colName} - } else { - autoCond = builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"}) - } + + autoCond = session.Engine.CondDeleted(colName) } } } diff --git a/statement.go b/statement.go index 80fad4f4..6e360bb3 100644 --- a/statement.go +++ b/statement.go @@ -490,224 +490,6 @@ func (statement *Statement) colName(col *core.Column, tableName string) string { return statement.Engine.Quote(col.Name) } -func buildConds(engine *Engine, table *core.Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, - includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { - var conds []builder.Cond - for _, col := range table.Columns() { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - if !includeAutoIncr && col.IsAutoIncrement { - continue - } - - if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { - continue - } - if col.SQLType.IsJson() { - continue - } - - var colName string - if addedTableName { - var nm = tableName - if len(aliasName) > 0 { - nm = aliasName - } - colName = engine.Quote(nm) + "." + engine.Quote(col.Name) - } else { - colName = engine.Quote(col.Name) - } - - fieldValuePtr, err := col.ValueOf(bean) - if err != nil { - engine.logger.Error(err) - continue - } - - if col.IsDeleted && !unscoped { // tag "deleted" is enabled - if engine.dialect.DBType() == core.MSSQL { - conds = append(conds, builder.IsNull{colName}) - } else { - conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})) - } - } - - fieldValue := *fieldValuePtr - if fieldValue.Interface() == nil { - continue - } - - fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := useAllCols - - if b, ok := getFlagForColumn(mustColumnMap, col); ok { - if b { - requiredField = true - } else { - continue - } - } - - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - conds = append(conds, builder.Eq{colName: nil}) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } - - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - t := int64(fieldValue.Uint()) - val = reflect.ValueOf(&t).Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = engine.formatColTime(col, t) - } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil { - continue - } - } else { - if col.SQLType.IsJson() { - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) - } - } else { - val = fieldValue.Interface() - } - } - } - case reflect.Array: - continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.logger.Error(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - - conds = append(conds, builder.Eq{colName: val}) - } - - return builder.And(conds...), nil -} - // TableName return current tableName func (statement *Statement) TableName() string { if statement.AltTableName != "" { @@ -1104,7 +886,7 @@ func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interfa } func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) }