diff --git a/convert.go b/convert.go index 2316ca0b..35e931a6 100644 --- a/convert.go +++ b/convert.go @@ -25,11 +25,10 @@ func strconvErr(err error) error { func cloneBytes(b []byte) []byte { if b == nil { return nil - } else { - c := make([]byte, len(b)) - copy(c, b) - return c } + c := make([]byte, len(b)) + copy(c, b) + return c } func asString(src interface{}) string { @@ -274,11 +273,12 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return vv.String(), nil case reflect.Slice: if tp.Elem().Kind() == reflect.Uint8 { - v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) + return string(vv.Interface().([]byte)), nil + /*v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) if err != nil { return nil, err } - return v, nil + return v, nil*/ } } diff --git a/dialect_postgres.go b/dialect_postgres.go index 2b2a0b78..d992b993 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -820,7 +820,7 @@ func (db *postgres) SqlType(c *core.Column) string { case core.NVarchar: res = core.Varchar case core.Uuid: - res = core.Uuid + return core.Uuid case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: return core.Bytea case core.Double: diff --git a/engine.go b/engine.go index 52ec1e3f..42fa5b63 100644 --- a/engine.go +++ b/engine.go @@ -9,7 +9,6 @@ import ( "bytes" "database/sql" "encoding/gob" - "errors" "fmt" "io" "os" @@ -21,6 +20,7 @@ import ( "github.com/go-xorm/builder" "github.com/go-xorm/core" + "github.com/pkg/errors" ) // Engine is the major struct of xorm, it means a database manager. @@ -799,13 +799,18 @@ func (engine *Engine) UnMapType(t reflect.Type) { } func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { - t := v.Type() engine.mutex.Lock() defer engine.mutex.Unlock() + return engine.autoMapTypeNoLock(v) +} + +func (engine *Engine) autoMapTypeNoLock(v reflect.Value) (*core.Table, error) { + t := v.Type() table, ok := engine.Tables[t] if !ok { var err error - table, err = engine.mapType(v) + var parsingTables = make(map[reflect.Type]*core.Table) + table, err = engine.mapType(parsingTables, v) if err != nil { return nil, err } @@ -879,9 +884,17 @@ var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() ) -func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { +func (engine *Engine) mapType(parsingTables map[reflect.Type]*core.Table, v reflect.Value) (*core.Table, error) { + if v.Kind() != reflect.Struct { + return nil, errors.New("need a struct to map") + } t := v.Type() + if table, ok := parsingTables[t]; ok { + return table, nil + } + table := engine.newTable() + parsingTables[t] = table if tb, ok := v.Interface().(TableName); ok { table.Name = tb.TableName() } else { @@ -908,24 +921,38 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { fieldValue := v.Field(i) fieldType := fieldValue.Type() - if ormTagStr != "" { - col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]int)} - tags := splitTag(ormTagStr) + var ctx = tagContext{ + engine: engine, + parsingTables: parsingTables, + table: table, + hasCacheTag: false, + hasNoCacheTag: false, + + fieldValue: fieldValue, + indexNames: make(map[string]int), + isIndex: false, + isUnique: false, + } + + if ormTagStr != "" { + col = &core.Column{ + FieldName: t.Field(i).Name, + FieldType: t.Field(i).Type, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: core.TWOSIDES, + Indexes: make(map[string]int), + } + ctx.col = col + + tags := splitTag(ormTagStr) if len(tags) > 0 { if tags[0] == "-" { continue } - var ctx = tagContext{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - engine: engine, - } - if strings.ToUpper(tags[0]) == "EXTENDS" { if err := ExtendsTagHandler(&ctx); err != nil { return nil, err @@ -1021,20 +1048,57 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { } else { sqlType = core.Type2SQLType(fieldType) } - col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) + + col = core.NewColumn( + engine.ColumnMapper.Obj2Table(t.Field(i).Name), + t.Field(i).Name, + + sqlType, + sqlType.DefaultLength, + sqlType.DefaultLength2, + true, + ) if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { idFieldColName = col.Name } + + col.FieldType = t.Field(i).Type + ctx.col = col } if col.IsAutoIncrement { col.Nullable = false } - table.AddColumn(col) + var tp = fieldType + if isPtrStruct(fieldType) { + tp = fieldType.Elem() + } + if isStruct(tp) && col.AssociateTable == nil { + var isBelongsTo = !(tp.ConvertibleTo(core.TimeType) || col.SQLType.IsJson()) + if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Interface().(sql.Scanner); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Interface().(core.Conversion); ok { + isBelongsTo = false + } + if isBelongsTo { + err := BelongsToTagHandler(&ctx) + if err != nil { + return nil, err + } + col.AssociateType = core.AssociateNone + } + } + + table.AddColumn(col) } // end for if idFieldColName != "" && len(table.PrimaryKeys) == 0 { @@ -1444,6 +1508,13 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } +// Load loads bean's belongs to tag field immedicatlly +func (engine *Engine) Load(bean interface{}, cols ...string) error { + session := engine.NewSession() + defer session.Close() + return session.Load(bean, cols...) +} + // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct diff --git a/helpers.go b/helpers.go index f39ed472..85240b24 100644 --- a/helpers.go +++ b/helpers.go @@ -16,6 +16,14 @@ import ( "github.com/go-xorm/core" ) +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct || isPtrStruct(t) +} + +func isPtrStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + // str2PK convert string value to primary key value according to tp func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { var err error @@ -96,26 +104,6 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { return v.Interface(), nil } -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 - } - } - } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) - } - return -} - type zeroable interface { IsZero() bool } @@ -471,3 +459,12 @@ func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) return false, false } + +func isStringInSlice(s string, slice []string) bool { + for _, e := range slice { + if s == e { + return true + } + } + return false +} diff --git a/helpers_test.go b/helpers_test.go index d57c54ae..31aec506 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -3,24 +3,3 @@ // license that can be found in the LICENSE file. package xorm - -import "testing" - -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} diff --git a/interface.go b/interface.go index 85a46a27..f727ddd7 100644 --- a/interface.go +++ b/interface.go @@ -18,6 +18,7 @@ type Interface interface { Alias(alias string) *Session Asc(colNames ...string) *Session BufferSize(size int) *Session + Cascade(...bool) *Session Cols(columns ...string) *Session Count(...interface{}) (int64, error) CreateIndexes(bean interface{}) error @@ -27,6 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error + Load(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 5c6cb5f9..3f9656ee 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,13 @@ import ( "github.com/go-xorm/core" ) +type loadClosure struct { + Func func(core.PK, *reflect.Value) error + pk core.PK + fieldValue *reflect.Value + loaded bool +} + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -51,6 +58,9 @@ type Session struct { lastSQL string lastSQLArgs []interface{} + cascadeMode cascadeMode + cascadeLevel int // load level + err error } @@ -82,6 +92,9 @@ func (session *Session) Init() { session.lastSQL = "" session.lastSQLArgs = []interface{}{} + + session.cascadeMode = cascadeCompitable + session.cascadeLevel = 2 } // Close release the connection from pool @@ -149,7 +162,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.UseCascade = false + session.cascadeMode = cascadeLazy return session } @@ -204,9 +217,16 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { + var mode = cascadeEager if len(trueOrFalse) >= 1 { - session.statement.UseCascade = trueOrFalse[0] + if trueOrFalse[0] { + mode = cascadeEager + } else { + mode = cascadeLazy + } } + + session.cascadeMode = mode return session } @@ -440,8 +460,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } - rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) + rawValueType := vv.Type() col := table.GetColumnIdx(key, idx) if col.IsPrimaryKey { pk = append(pk, rawValue.Interface()) @@ -629,175 +649,205 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - } else if col.SQLType.IsJson() { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } + } else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) + var err error + rawValueType := col.AssociateTable.PKColumns()[0].FieldType + if rawValueType.Kind() == reflect.Ptr { + pk[0] = reflect.New(rawValueType.Elem()).Interface() + } else { + pk[0] = reflect.New(rawValueType).Interface() } - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) + err = convertAssign(pk[0], vv.Interface()) if err != nil { return nil, err } + pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + session.cascadeLevel-- hasAssigned = true - if len(table.PrimaryKeys) != 1 { - return nil, errors.New("unsupported non or composited primary key cascade") - } - var pk = make(core.PK, len(table.PrimaryKeys)) - pk[0], err = asKind(vv, rawValueType) + } else if col.AssociateType == core.AssociateBelongsTo { + hasAssigned = true + err := convertAssign(fieldValue.FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), + vv.Interface()) if err != nil { return nil, err } - - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return nil, err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return nil, errors.New("cascade obj is not exist") - } - } } case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct { + if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) + var err error + rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) + if rawValueType.Kind() == reflect.Ptr { + pk[0] = reflect.New(rawValueType.Elem()).Interface() + } else { + pk[0] = reflect.New(rawValueType).Interface() + } + err = convertAssign(pk[0], vv.Interface()) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - case core.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + + pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + + session.cascadeLevel-- + hasAssigned = true + } else if col.AssociateType == core.AssociateBelongsTo { + hasAssigned = true + if fieldValue.IsNil() { + // FIXME: find id column + structInter := reflect.New(fieldValue.Type().Elem()) + fieldValue.Set(structInter) + } + + err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), + vv.Interface()) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) } - hasAssigned = true - } // switch fieldType + } else { + // !nashtsai! TODO merge duplicated codes above + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case core.PtrStringType: + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrBoolType: + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrTimeType: + if rawValueType == core.PtrTimeType { + hasAssigned = true + var x = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat64Type: + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint64Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt64Type: + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat32Type: + if rawValueType.Kind() == reflect.Float64 { + var x = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrIntType: + if rawValueType.Kind() == reflect.Int64 { + var x = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUintType: + if rawValueType.Kind() == reflect.Int64 { + var x = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrComplex64Type: + var x complex64 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.engine.logger.Error(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + } + hasAssigned = true + case core.PtrComplex128Type: + var x complex128 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.engine.logger.Error(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + } + hasAssigned = true + } // switch fieldType + } } // switch fieldType.Kind() // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value @@ -816,6 +866,40 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b return pk, nil } +func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { + if !isPKZero(pk) { + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface()) + if err != nil { + return err + } + if has { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + fmt.Println("getByPK value ptr:", fieldValue.Interface()) + } else if fieldValue.Kind() == reflect.Struct { + fieldValue.Set(structInter.Elem()) + fmt.Println("getByPK value:", fieldValue.Interface()) + } else { + return errors.New("set value failed") + } + } else { + return errors.New("cascade obj is not exist") + } + } + return nil +} + // saveLastSQL stores executed query information func (session *Session) saveLastSQL(sql string, args ...interface{}) { session.lastSQL = sql diff --git a/session_associate.go b/session_associate.go new file mode 100644 index 00000000..62e924a3 --- /dev/null +++ b/session_associate.go @@ -0,0 +1,145 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "reflect" + + "github.com/go-xorm/core" +) + +// Load loads associated fields from database +func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { + v := reflect.ValueOf(beanOrSlices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + return session.loadFind(beanOrSlices, cols...) + } else if v.Kind() == reflect.Struct { + return session.loadGet(beanOrSlices, cols...) + } + return errors.New("unsupported load type, must struct or slice") +} + +// loadFind load 's belongs to tag field immedicatlly +func (session *Session) loadFind(slices interface{}, cols ...string) error { + v := reflect.ValueOf(slices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + return errors.New("only slice is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.engine.autoMapType(vv) + if err != nil { + return err + } + + var pks = make(map[*core.Column][]interface{}) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + /*var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + }*/ + + if !isZero(pk[0]) { + pks[col] = append(pks[col], pk[0]) + } + } + } + } + } + + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) + if err != nil { + return err + } + } + return nil +} + +// loadGet load bean's belongs to tag field immedicatlly +func (session *Session) loadGet(bean interface{}, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + v := rValue(bean) + tb, err := session.engine.autoMapType(v) + if err != nil { + return err + } + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) && session.cascadeLevel > 0 { + has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("load bean does not exist") + } + session.cascadeLevel-- + } + } + } + } + return nil +} diff --git a/session_associate_test.go b/session_associate_test.go new file mode 100644 index 00000000..94624216 --- /dev/null +++ b/session_associate_test.go @@ -0,0 +1,232 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBelongsTo_Get(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face1 struct { + Id int64 + Name string + } + + type Nose1 struct { + Id int64 + Face Face1 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose1), new(Face1)) + assert.NoError(t, err) + + var face = Face1{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face1 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose1{Face: face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose1 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "", cfgNose.Face.Name) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose1 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_GetPtr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face2 struct { + Id int64 + Name string + } + + type Nose2 struct { + Id int64 + Face *Face2 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose2), new(Face2)) + assert.NoError(t, err) + + var face = Face2{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face2 + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, face, cfgFace) + + var nose = Nose2{Face: &face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose2 + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + + err = testEngine.Load(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + + var cfgNose2 Nose2 + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_Find(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face3 struct { + Id int64 + Name string + } + + type Nose3 struct { + Id int64 + Face Face3 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose3), new(Face3)) + assert.NoError(t, err) + + var face1 = Face3{ + Name: "face1", + } + var face2 = Face3{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose3{ + {Face: face1}, + {Face: face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose3 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose3 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses1, "face") + assert.NoError(t, err) + assert.Equal(t, "face1", noses1[0].Face.Name) + assert.Equal(t, "face2", noses1[1].Face.Name) +} + +func TestBelongsTo_FindPtr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face4 struct { + Id int64 + Name string + } + + type Nose4 struct { + Id int64 + Face *Face4 `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose4), new(Face4)) + assert.NoError(t, err) + + var face1 = Face4{ + Name: "face1", + } + var face2 = Face4{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose4{ + {Face: &face1}, + {Face: &face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose4 + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose4 + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.NotNil(t, noses2[0].Face) + assert.NotNil(t, noses2[1].Face) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses2, "face") + assert.NoError(t, err) +} diff --git a/session_convert.go b/session_convert.go index 1f9d8aa1..7ac88d40 100644 --- a/session_convert.go +++ b/session_convert.go @@ -8,7 +8,6 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "errors" "fmt" "reflect" "strconv" @@ -203,39 +202,22 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return err - } - - // TODO: current only support 1 primary key - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(core.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + } else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) + // only 1 PK checked on tag parsing + rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) + var err error pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } + session.cascadeLevel-- + if err = session.getByPK(pk, fieldValue); err != nil { + return err } } } @@ -485,41 +467,23 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(&x)) default: - if session.statement.UseCascade { - structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.autoMapType(structInter.Elem()) - if err != nil { - return err - } - - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(core.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.cascadeMode == cascadeEager)) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) + var err error + // only 1 PK checked on tag parsing + rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } + session.cascadeLevel-- + if err = session.getByPK(pk, fieldValue); err != nil { + return err } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } default: diff --git a/session_find.go b/session_find.go index 79817da3..ffcbf826 100644 --- a/session_find.go +++ b/session_find.go @@ -205,9 +205,8 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va var newElemFunc func(fields []string) reflect.Value elemType := containerValue.Type().Elem() - var isPointer bool - if elemType.Kind() == reflect.Ptr { - isPointer = true + var isPointer = elemType.Kind() == reflect.Ptr + if isPointer { elemType = elemType.Elem() } if elemType.Kind() == reflect.Ptr { @@ -237,7 +236,9 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va if isPointer { containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr())) } else { + fmt.Println("---", newValue.Elem()) containerValue.Set(reflect.Append(containerValue, newValue.Elem())) + fmt.Println("===", containerValue.Interface()) } return nil } diff --git a/session_schema.go b/session_schema.go index 9d9edca8..cd703b43 100644 --- a/session_schema.go +++ b/session_schema.go @@ -265,7 +265,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - table, err := engine.mapType(v) + table, err := engine.autoMapType(v) if err != nil { return err } diff --git a/statement.go b/statement.go index 02d73559..1ba7e517 100644 --- a/statement.go +++ b/statement.go @@ -33,6 +33,15 @@ type exprParam struct { expr string } +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything +) + // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -54,7 +63,6 @@ type Statement struct { tableName string RawSQL string RawParams []interface{} - UseCascade bool UseAutoJoin bool StoreEngine string Charset string @@ -82,7 +90,6 @@ func (statement *Statement) Init() { statement.Start = 0 statement.LimitN = 0 statement.OrderStr = "" - statement.UseCascade = true statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" diff --git a/tag.go b/tag.go index e1c821fb..40405cc7 100644 --- a/tag.go +++ b/tag.go @@ -5,6 +5,7 @@ package xorm import ( + "errors" "fmt" "reflect" "strconv" @@ -15,46 +16,71 @@ import ( ) type tagContext struct { + engine *Engine + parsingTables map[reflect.Type]*core.Table + + table *core.Table + hasCacheTag bool + hasNoCacheTag bool + + col *core.Column + fieldValue reflect.Value + isIndex bool + isUnique bool + indexNames map[string]int + tagName string params []string preTag, nextTag string - table *core.Table - col *core.Column - fieldValue reflect.Value - isIndex bool - isUnique bool - indexNames map[string]int - engine *Engine - hasCacheTag bool - hasNoCacheTag bool ignoreNext bool } +func splitTag(tag string) (tags []string) { + tag = strings.TrimSpace(tag) + var hasQuote = false + var lastIdx = 0 + for i, t := range tag { + if t == '\'' { + hasQuote = !hasQuote + } else if t == ' ' { + if lastIdx < i && !hasQuote { + tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) + lastIdx = i + 1 + } + } + } + if lastIdx < len(tag) { + tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + } + return +} + // tagHandler describes tag handler for XORM type tagHandler func(ctx *tagContext) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]tagHandler{ - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": IgnoreTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": IgnoreTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "BELONGS_TO": BelongsToTagHandler, } ) @@ -256,15 +282,16 @@ func ExtendsTagHandler(ctx *tagContext) error { } fallthrough case reflect.Struct: - parentTable, err := ctx.engine.mapType(fieldValue) + parentTable, err := ctx.engine.mapType(ctx.parsingTables, fieldValue) if err != nil { return err } - for _, col := range parentTable.Columns() { + for _, oriCol := range parentTable.Columns() { + col := *oriCol col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) - ctx.table.AddColumn(col) + ctx.table.AddColumn(&col) for indexName, indexType := range col.Indexes { - addIndex(indexName, ctx.table, col, indexType) + addIndex(indexName, ctx.table, &col, indexType) } } default: @@ -288,3 +315,44 @@ func NoCacheTagHandler(ctx *tagContext) error { } return nil } + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *tagContext) error { + if !isStruct(ctx.fieldValue.Type()) { + return errors.New("Tag belongs_to cannot be applied on non-struct field") + } + + ctx.col.AssociateType = core.AssociateBelongsTo + var t reflect.Value + if ctx.fieldValue.Kind() == reflect.Struct { + t = ctx.fieldValue + } else { + if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct { + if ctx.fieldValue.IsNil() { + t = reflect.New(ctx.fieldValue.Type().Elem()).Elem() + } else { + t = ctx.fieldValue + } + } else { + return errors.New("Only struct or ptr to struct field could add belongs_to flag") + } + } + + belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + if err != nil { + return err + } + pks := belongsT.PKColumns() + if len(pks) != 1 { + panic("unsupported non or composited primary key cascade") + return errors.New("blongs_to only should be as a tag of table has one primary key") + } + + ctx.col.AssociateTable = belongsT + ctx.col.SQLType = pks[0].SQLType + + if len(ctx.col.Name) == 0 { + ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + } + return nil +} diff --git a/tag_extends_test.go b/tag_extends_test.go index b70eefe3..f0af21d7 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -169,22 +169,12 @@ func TestExtends(t *testing.T) { tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) users := make([]tempUser3, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - err = errors.New("error get data not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) assertSync(t, new(Userinfo), new(Userdetail)) @@ -209,21 +199,9 @@ func TestExtends(t *testing.T) { sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) - if err != nil { - t.Error(err) - panic(err) - } - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - fmt.Println(info) - if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, b) + assert.False(t, info.Userinfo.Uid == 0 || info.Userdetail.Id == 0) fmt.Println("----join--info2") var info2 UserAndDetail @@ -536,3 +514,55 @@ func TestExtends4(t *testing.T) { panic(err) } } + +func TestExtendsTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + table := testEngine.TableInfo(new(Userdetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 3, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[1]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[2]) + + table = testEngine.TableInfo(new(Userinfo)) + assert.NotNil(t, table) + assert.EqualValues(t, 8, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + + table = testEngine.TableInfo(new(UserAndDetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 11, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + assert.EqualValues(t, "id", table.ColumnsSeq()[8]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[9]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[10]) + + cols := table.Columns() + assert.EqualValues(t, 11, len(cols)) + assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName) + assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName) + assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName) + assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName) + assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName) + assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName) + assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName) + assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName) + assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName) + assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName) + assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName) +} diff --git a/tag_test.go b/tag_test.go index c9b76048..f97f6a10 100644 --- a/tag_test.go +++ b/tag_test.go @@ -393,3 +393,22 @@ func TestTagTime(t *testing.T) { assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) } + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !sliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +} diff --git a/xorm_test.go b/xorm_test.go index 569bc681..d86fedf0 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -12,6 +12,7 @@ import ( "github.com/go-xorm/core" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" _ "github.com/ziutek/mymysql/godrv" ) @@ -32,10 +33,8 @@ var ( func createEngine(dbType, connStr string) error { if testEngine == nil { var err error - if !*cluster { testEngine, err = NewEngine(dbType, connStr) - } else { testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) } @@ -123,6 +122,8 @@ func TestMain(m *testing.M) { } func TestPing(t *testing.T) { + assert.NoError(t, prepareEngine()) + if err := testEngine.Ping(); err != nil { t.Fatal(err) }