From 557d5a4101df9e1f7eeffc3696a43ccf2d378514 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 14 Mar 2017 22:25:10 +0800 Subject: [PATCH 01/12] implement simple belongs_to tag --- blongs_to_test.go | 149 ++++++++++++++++++++++++++++++++++++++++++++ dialect_postgres.go | 2 +- engine.go | 101 ++++++++++++++++++++++++------ helpers.go | 8 +++ session.go | 26 +++++--- session_convert.go | 27 ++++---- session_schema.go | 2 +- statement.go | 12 +++- tag.go | 105 ++++++++++++++++++++++--------- xorm_test.go | 3 + 10 files changed, 359 insertions(+), 76 deletions(-) create mode 100644 blongs_to_test.go diff --git a/blongs_to_test.go b/blongs_to_test.go new file mode 100644 index 00000000..b6d8b1b5 --- /dev/null +++ b/blongs_to_test.go @@ -0,0 +1,149 @@ +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBelongsTo_Get(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face = Face{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgFace, face) + + var nose = Nose{Face: face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose.Id, nose.Id) + // FIXME: the id should be set back to the field + //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, "", cfgNose.Face.Name) + + var cfgNose2 Nose + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose2.Id, nose.Id) + assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_GetPtr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face *Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face = Face{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgFace, face) + + var nose = Nose{Face: &face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose.Id, nose.Id) + // FIXME: the id should be set back to the field + //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + + var cfgNose2 Nose + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose2.Id, nose.Id) + assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_Find(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face1 = Face{ + Name: "face1", + } + var face2 = Face{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose{ + {Face: face1}, + {Face: face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + // FIXME: + //assert.Equal(t, face1.Id, noses1[0].Face.Id) + //assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) +} diff --git a/dialect_postgres.go b/dialect_postgres.go index 3f5c526f..0f9679c7 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -811,7 +811,7 @@ func (db *postgres) SqlType(c *core.Column) string { case core.NVarchar: res = core.Varchar case core.Uuid: - res = core.Uuid + return core.Uuid case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: return core.Bytea case core.Double: diff --git a/engine.go b/engine.go index 2b986966..5ca568e1 100644 --- a/engine.go +++ b/engine.go @@ -9,7 +9,6 @@ import ( "bytes" "database/sql" "encoding/gob" - "errors" "fmt" "io" "os" @@ -21,6 +20,7 @@ import ( "github.com/go-xorm/builder" "github.com/go-xorm/core" + "github.com/pkg/errors" ) // Engine is the major struct of xorm, it means a database manager. @@ -792,13 +792,18 @@ func (engine *Engine) UnMapType(t reflect.Type) { } func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { - t := v.Type() engine.mutex.Lock() defer engine.mutex.Unlock() + return engine.autoMapTypeNoLock(v) +} + +func (engine *Engine) autoMapTypeNoLock(v reflect.Value) (*core.Table, error) { + t := v.Type() table, ok := engine.Tables[t] if !ok { var err error - table, err = engine.mapType(v) + var parsingTables = make(map[reflect.Type]*core.Table) + table, err = engine.mapType(parsingTables, v) if err != nil { return nil, err } @@ -872,9 +877,17 @@ var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() ) -func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { +func (engine *Engine) mapType(parsingTables map[reflect.Type]*core.Table, v reflect.Value) (*core.Table, error) { + if v.Kind() != reflect.Struct { + return nil, errors.New("need a struct to map") + } t := v.Type() + if table, ok := parsingTables[t]; ok { + return table, nil + } + table := engine.newTable() + parsingTables[t] = table if tb, ok := v.Interface().(TableName); ok { table.Name = tb.TableName() } else { @@ -901,24 +914,37 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { fieldValue := v.Field(i) fieldType := fieldValue.Type() - if ormTagStr != "" { - col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]int)} - tags := splitTag(ormTagStr) + var ctx = tagContext{ + engine: engine, + parsingTables: parsingTables, + table: table, + hasCacheTag: false, + hasNoCacheTag: false, + + fieldValue: fieldValue, + indexNames: make(map[string]int), + isIndex: false, + isUnique: false, + } + + if ormTagStr != "" { + col = &core.Column{ + FieldName: t.Field(i).Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: core.TWOSIDES, + Indexes: make(map[string]int), + } + ctx.col = col + + tags := splitTag(ormTagStr) if len(tags) > 0 { if tags[0] == "-" { continue } - var ctx = tagContext{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - engine: engine, - } - if strings.ToUpper(tags[0]) == "EXTENDS" { if err := ExtendsTagHandler(&ctx); err != nil { return nil, err @@ -1014,20 +1040,55 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { } else { sqlType = core.Type2SQLType(fieldType) } - col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) + + col = core.NewColumn( + engine.ColumnMapper.Obj2Table(t.Field(i).Name), + t.Field(i).Name, + sqlType, + sqlType.DefaultLength, + sqlType.DefaultLength2, + true, + ) if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { idFieldColName = col.Name } + + ctx.col = col } if col.IsAutoIncrement { col.Nullable = false } - table.AddColumn(col) + var tp = fieldType + if isPtrStruct(fieldType) { + tp = fieldType.Elem() + } + if isStruct(tp) && col.AssociateTable == nil { + var isBelongsTo = !(tp.ConvertibleTo(core.TimeType) || col.SQLType.IsJson()) + if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Interface().(sql.Scanner); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Interface().(core.Conversion); ok { + isBelongsTo = false + } + if isBelongsTo { + err := BelongsToTagHandler(&ctx) + if err != nil { + return nil, err + } + col.AssociateType = core.AssociateNone + } + } + + table.AddColumn(col) } // end for if idFieldColName != "" && len(table.PrimaryKeys) == 0 { diff --git a/helpers.go b/helpers.go index f39ed472..b505a995 100644 --- a/helpers.go +++ b/helpers.go @@ -16,6 +16,14 @@ import ( "github.com/go-xorm/core" ) +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct || isPtrStruct(t) +} + +func isPtrStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + // str2PK convert string value to primary key value according to tp func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { var err error diff --git a/session.go b/session.go index 5c6cb5f9..789ec35c 100644 --- a/session.go +++ b/session.go @@ -149,7 +149,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.UseCascade = false + session.statement.cascadeMode = cascadeManuallyLoad return session } @@ -204,9 +204,16 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { + var mode = cascadeAutoLoad if len(trueOrFalse) >= 1 { - session.statement.UseCascade = trueOrFalse[0] + if trueOrFalse[0] { + mode = cascadeAutoLoad + } else { + mode = cascadeManuallyLoad + } } + + session.statement.cascadeMode = mode return session } @@ -629,7 +636,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - } else if col.SQLType.IsJson() { + /*} else if col.SQLType.IsJson() { if rawValueType.Kind() == reflect.String { hasAssigned = true x := reflect.New(fieldType) @@ -650,18 +657,19 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } fieldValue.Set(x.Elem()) } - } - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return nil, err - } + }*/ + } else if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable hasAssigned = true if len(table.PrimaryKeys) != 1 { return nil, errors.New("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) + var err error pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err diff --git a/session_convert.go b/session_convert.go index f2c949ba..1e6355c3 100644 --- a/session_convert.go +++ b/session_convert.go @@ -203,11 +203,11 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return err - } + } else if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable // TODO: current only support 1 primary key if len(table.PrimaryKeys) > 1 { @@ -216,6 +216,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + var err error pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err @@ -485,18 +486,17 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(&x)) default: - if session.statement.UseCascade { - structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.autoMapType(structInter.Elem()) - if err != nil { - return err - } - + if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable if len(table.PrimaryKeys) > 1 { return errors.New("unsupported composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) + var err error rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { @@ -504,6 +504,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } if !isPKZero(pk) { + structInter := reflect.New(fieldType.Elem()) // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -518,8 +519,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return errors.New("cascade obj is not exist") } } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } default: diff --git a/session_schema.go b/session_schema.go index a2708b73..8960cdfe 100644 --- a/session_schema.go +++ b/session_schema.go @@ -259,7 +259,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - table, err := engine.mapType(v) + table, err := engine.autoMapType(v) if err != nil { return err } diff --git a/statement.go b/statement.go index 23346c71..4f4631ea 100644 --- a/statement.go +++ b/statement.go @@ -33,6 +33,14 @@ type exprParam struct { expr string } +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota + cascadeAutoLoad + cascadeManuallyLoad +) + // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -54,7 +62,7 @@ type Statement struct { tableName string RawSQL string RawParams []interface{} - UseCascade bool + cascadeMode cascadeMode UseAutoJoin bool StoreEngine string Charset string @@ -82,7 +90,7 @@ func (statement *Statement) Init() { statement.Start = 0 statement.LimitN = 0 statement.OrderStr = "" - statement.UseCascade = true + statement.cascadeMode = cascadeCompitable statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" diff --git a/tag.go b/tag.go index e1c821fb..835c65f3 100644 --- a/tag.go +++ b/tag.go @@ -5,6 +5,7 @@ package xorm import ( + "errors" "fmt" "reflect" "strconv" @@ -15,18 +16,22 @@ import ( ) type tagContext struct { + engine *Engine + parsingTables map[reflect.Type]*core.Table + + table *core.Table + hasCacheTag bool + hasNoCacheTag bool + + col *core.Column + fieldValue reflect.Value + isIndex bool + isUnique bool + indexNames map[string]int + tagName string params []string preTag, nextTag string - table *core.Table - col *core.Column - fieldValue reflect.Value - isIndex bool - isUnique bool - indexNames map[string]int - engine *Engine - hasCacheTag bool - hasNoCacheTag bool ignoreNext bool } @@ -36,25 +41,26 @@ type tagHandler func(ctx *tagContext) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]tagHandler{ - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": IgnoreTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": IgnoreTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "BELONGS_TO": BelongsToTagHandler, } ) @@ -256,7 +262,7 @@ func ExtendsTagHandler(ctx *tagContext) error { } fallthrough case reflect.Struct: - parentTable, err := ctx.engine.mapType(fieldValue) + parentTable, err := ctx.engine.mapType(ctx.parsingTables, fieldValue) if err != nil { return err } @@ -288,3 +294,44 @@ func NoCacheTagHandler(ctx *tagContext) error { } return nil } + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *tagContext) error { + if !isStruct(ctx.fieldValue.Type()) { + return errors.New("Tag belongs_to cannot be applied on non-struct field") + } + + ctx.col.AssociateType = core.AssociateBelongsTo + var t reflect.Value + if ctx.fieldValue.Kind() == reflect.Struct { + t = ctx.fieldValue + } else { + if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct { + if ctx.fieldValue.IsNil() { + t = reflect.New(ctx.fieldValue.Type().Elem()).Elem() + } else { + t = ctx.fieldValue + } + } else { + return errors.New("Only struct or ptr to struct field could add belongs_to flag") + } + } + + belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + if err != nil { + return err + } + pks := belongsT.PKColumns() + if len(pks) != 1 { + panic("unsupported non or composited primary key cascade") + return errors.New("blongs_to only should be as a tag of table has one primary key") + } + + ctx.col.AssociateTable = belongsT + ctx.col.SQLType = pks[0].SQLType + + if len(ctx.col.Name) == 0 { + ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + } + return nil +} diff --git a/xorm_test.go b/xorm_test.go index 569bc681..f7581e06 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -12,6 +12,7 @@ import ( "github.com/go-xorm/core" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" _ "github.com/ziutek/mymysql/godrv" ) @@ -123,6 +124,8 @@ func TestMain(m *testing.M) { } func TestPing(t *testing.T) { + assert.NoError(t, prepareEngine()) + if err := testEngine.Ping(); err != nil { t.Fatal(err) } From 8a3fa4464ddc59806cb39202ae3edb57290a07b6 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 26 Mar 2017 09:47:35 +0800 Subject: [PATCH 02/12] add cascade find test and vendor folder --- blongs_to_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/blongs_to_test.go b/blongs_to_test.go index b6d8b1b5..4de7a3d0 100644 --- a/blongs_to_test.go +++ b/blongs_to_test.go @@ -146,4 +146,14 @@ func TestBelongsTo_Find(t *testing.T) { //assert.Equal(t, face2.Id, noses1[1].Face.Id) assert.Equal(t, "", noses1[0].Face.Name) assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + // FIXME: + //assert.Equal(t, face1.Id, noses2[0].Face.Id) + //assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) } From 8490767f1e17ca7cbd2c4c7f67d9a03ac3097923 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 4 Apr 2017 12:19:37 +0800 Subject: [PATCH 03/12] set column value to belongs_to bean when cascade is disabled --- blongs_to_test.go | 26 ++-- convert.go | 12 +- engine.go | 7 + interface.go | 2 + session.go | 305 +++++++++++++++++++++++++------------------ session_associate.go | 58 ++++++++ xorm_test.go | 2 - 7 files changed, 268 insertions(+), 144 deletions(-) create mode 100644 session_associate.go diff --git a/blongs_to_test.go b/blongs_to_test.go index 4de7a3d0..eb0d93e1 100644 --- a/blongs_to_test.go +++ b/blongs_to_test.go @@ -43,16 +43,21 @@ func TestBelongsTo_Get(t *testing.T) { assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, cfgNose.Id, nose.Id) - // FIXME: the id should be set back to the field - //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) + err = testEngine.EagerLoad(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, cfgNose.Id, nose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) + var cfgNose2 Nose has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, cfgNose2.Id, nose.Id) - assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) assert.Equal(t, "face1", cfgNose2.Face.Name) } @@ -92,16 +97,21 @@ func TestBelongsTo_GetPtr(t *testing.T) { has, err = testEngine.Get(&cfgNose) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgNose.Id, nose.Id) - // FIXME: the id should be set back to the field - //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + + err = testEngine.EagerLoad(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) + assert.Equal(t, "face1", cfgNose.Face.Name) var cfgNose2 Nose has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgNose2.Id, nose.Id) - assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Id, cfgNose2.Id) + assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) assert.Equal(t, "face1", cfgNose2.Face.Name) } diff --git a/convert.go b/convert.go index 0504bef1..db6c7fab 100644 --- a/convert.go +++ b/convert.go @@ -25,11 +25,10 @@ func strconvErr(err error) error { func cloneBytes(b []byte) []byte { if b == nil { return nil - } else { - c := make([]byte, len(b)) - copy(c, b) - return c } + c := make([]byte, len(b)) + copy(c, b) + return c } func asString(src interface{}) string { @@ -274,11 +273,12 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { return vv.String(), nil case reflect.Slice: if tp.Elem().Kind() == reflect.Uint8 { - v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) + return string(vv.Interface().([]byte)), nil + /*v, err := strconv.ParseInt(string(vv.Interface().([]byte)), 10, 64) if err != nil { return nil, err } - return v, nil + return v, nil*/ } } diff --git a/engine.go b/engine.go index 5ca568e1..c3e0d8d0 100644 --- a/engine.go +++ b/engine.go @@ -1498,6 +1498,13 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } +// EagerLoad loads bean's belongs to tag field immedicatlly +func (engine *Engine) EagerLoad(bean interface{}, cols ...string) error { + session := engine.NewSession() + defer session.Close() + return session.EagerLoad(bean, cols...) +} + // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct diff --git a/interface.go b/interface.go index 4f94750b..4adee92e 100644 --- a/interface.go +++ b/interface.go @@ -18,6 +18,7 @@ type Interface interface { Alias(alias string) *Session Asc(colNames ...string) *Session BufferSize(size int) *Session + Cascade(...bool) *Session Cols(columns ...string) *Session Count(...interface{}) (int64, error) CreateIndexes(bean interface{}) error @@ -27,6 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error + EagerLoad(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 789ec35c..d7677163 100644 --- a/session.go +++ b/session.go @@ -636,28 +636,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - /*} else if col.SQLType.IsJson() { - if rawValueType.Kind() == reflect.String { - hasAssigned = true - x := reflect.New(fieldType) - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - } else if rawValueType.Kind() == reflect.Slice { - hasAssigned = true - x := reflect.New(fieldType) - if len(vv.Bytes()) > 0 { - err := json.Unmarshal(vv.Bytes(), x.Interface()) - if err != nil { - return nil, err - } - fieldValue.Set(x.Elem()) - } - }*/ } else if (col.AssociateType == core.AssociateNone && session.statement.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && @@ -690,122 +668,193 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b return nil, errors.New("cascade obj is not exist") } } + } else if col.AssociateType == core.AssociateBelongsTo { + hasAssigned = true + err := convertAssign(fieldValue.FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), + vv.Interface()) + if err != nil { + return nil, err + } } case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - switch fieldType { - // following types case matching ptr's native type, therefore assign ptr directly - case core.PtrStringType: - if rawValueType.Kind() == reflect.String { - x := vv.String() + if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct { + if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable + hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrBoolType: - if rawValueType.Kind() == reflect.Bool { - x := vv.Bool() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrTimeType: - if rawValueType == core.PtrTimeType { - hasAssigned = true - var x = rawValue.Interface().(time.Time) - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat64Type: - if rawValueType.Kind() == reflect.Float64 { - x := vv.Float() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint64Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint64(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt64Type: - if rawValueType.Kind() == reflect.Int64 { - x := vv.Int() - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrFloat32Type: - if rawValueType.Kind() == reflect.Float64 { - var x = float32(vv.Float()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrIntType: - if rawValueType.Kind() == reflect.Int64 { - var x = int(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrInt16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = int16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUintType: - if rawValueType.Kind() == reflect.Int64 { - var x = uint(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.PtrUint32Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint32(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint8Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint8(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Uint16Type: - if rawValueType.Kind() == reflect.Int64 { - var x = uint16(vv.Int()) - hasAssigned = true - fieldValue.Set(reflect.ValueOf(&x)) - } - case core.Complex64Type: - var x complex64 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + if len(table.PrimaryKeys) != 1 { + panic("unsupported non or composited primary key cascade") + } + var pk = make(core.PK, len(table.PrimaryKeys)) + var err error + pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) - } - hasAssigned = true - case core.Complex128Type: - var x complex128 - if len([]byte(vv.String())) > 0 { - err := json.Unmarshal([]byte(vv.String()), &x) + + if !isPKZero(pk) { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return nil, err + } + if has { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + } + } else { + return nil, errors.New("cascade obj is not exist") + } + } + } else if col.AssociateType == core.AssociateBelongsTo { + hasAssigned = true + if fieldValue.IsNil() { + structInter := reflect.New(fieldValue.Type().Elem()) + fieldValue.Set(structInter) + } + + //fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Set(vv) + err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), + vv.Interface()) if err != nil { return nil, err } - fieldValue.Set(reflect.ValueOf(&x)) } - hasAssigned = true - } // switch fieldType + } else { + // !nashtsai! TODO merge duplicated codes above + //typeStr := fieldType.String() + switch fieldType { + // following types case matching ptr's native type, therefore assign ptr directly + case core.PtrStringType: + if rawValueType.Kind() == reflect.String { + x := vv.String() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrBoolType: + if rawValueType.Kind() == reflect.Bool { + x := vv.Bool() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrTimeType: + if rawValueType == core.PtrTimeType { + hasAssigned = true + var x = rawValue.Interface().(time.Time) + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat64Type: + if rawValueType.Kind() == reflect.Float64 { + x := vv.Float() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint64Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint64(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt64Type: + if rawValueType.Kind() == reflect.Int64 { + x := vv.Int() + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrFloat32Type: + if rawValueType.Kind() == reflect.Float64 { + var x = float32(vv.Float()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrIntType: + if rawValueType.Kind() == reflect.Int64 { + var x = int(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrInt16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = int16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUintType: + if rawValueType.Kind() == reflect.Int64 { + var x = uint(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint32Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint32(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint8Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint8(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrUint16Type: + if rawValueType.Kind() == reflect.Int64 { + var x = uint16(vv.Int()) + hasAssigned = true + fieldValue.Set(reflect.ValueOf(&x)) + } + case core.PtrComplex64Type: + var x complex64 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.engine.logger.Error(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + } + hasAssigned = true + case core.PtrComplex128Type: + var x complex128 + if len([]byte(vv.String())) > 0 { + err := json.Unmarshal([]byte(vv.String()), &x) + if err != nil { + session.engine.logger.Error(err) + } else { + fieldValue.Set(reflect.ValueOf(&x)) + } + } + hasAssigned = true + } // switch fieldType + } } // switch fieldType.Kind() // !nashtsai! for value can't be assigned directly fallback to convert to []byte then back to value diff --git a/session_associate.go b/session_associate.go new file mode 100644 index 00000000..8e2c6231 --- /dev/null +++ b/session_associate.go @@ -0,0 +1,58 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "errors" + "reflect" + + "github.com/go-xorm/core" +) + +// EagerLoad load bean's belongs to tag field immedicatlly +func (session *Session) EagerLoad(bean interface{}, cols ...string) error { + if session.isAutoClose { + defer session.Close() + } + + v := rValue(bean) + tb, err := session.engine.autoMapType(v) + if err != nil { + return err + } + + for _, col := range tb.Columns() { + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&v) + if err != nil { + return err + } + + pk, err := session.engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) { + has, err := session.ID(pk).get(colPtr.Interface()) + if err != nil { + return err + } + if !has { + return errors.New("load bean does not exist") + } + } + } + } + } + return nil +} diff --git a/xorm_test.go b/xorm_test.go index f7581e06..d86fedf0 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -33,10 +33,8 @@ var ( func createEngine(dbType, connStr string) error { if testEngine == nil { var err error - if !*cluster { testEngine, err = NewEngine(dbType, connStr) - } else { testEngine, err = NewEngineGroup(dbType, strings.Split(connStr, *splitter)) } From c68711d9a5fa92ac99e81364e5c19c850a1e9903 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 4 Apr 2017 12:34:56 +0800 Subject: [PATCH 04/12] improved test for find --- blongs_to_test.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/blongs_to_test.go b/blongs_to_test.go index eb0d93e1..b5f3983a 100644 --- a/blongs_to_test.go +++ b/blongs_to_test.go @@ -32,7 +32,7 @@ func TestBelongsTo_Get(t *testing.T) { has, err := testEngine.Get(&cfgFace) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgFace, face) + assert.Equal(t, face, cfgFace) var nose = Nose{Face: face} _, err = testEngine.Insert(&nose) @@ -42,13 +42,13 @@ func TestBelongsTo_Get(t *testing.T) { has, err = testEngine.Get(&cfgNose) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgNose.Id, nose.Id) - assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, nose.Id, cfgNose.Id) + assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) err = testEngine.EagerLoad(&cfgNose) assert.NoError(t, err) - assert.Equal(t, cfgNose.Id, nose.Id) + assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "face1", cfgNose.Face.Name) @@ -56,7 +56,7 @@ func TestBelongsTo_Get(t *testing.T) { has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgNose2.Id, nose.Id) + assert.Equal(t, nose.Id, cfgNose2.Id) assert.Equal(t, nose.Face.Id, cfgNose2.Face.Id) assert.Equal(t, "face1", cfgNose2.Face.Name) } @@ -87,7 +87,7 @@ func TestBelongsTo_GetPtr(t *testing.T) { has, err := testEngine.Get(&cfgFace) assert.NoError(t, err) assert.Equal(t, true, has) - assert.Equal(t, cfgFace, face) + assert.Equal(t, face, cfgFace) var nose = Nose{Face: &face} _, err = testEngine.Insert(&nose) @@ -151,9 +151,8 @@ func TestBelongsTo_Find(t *testing.T) { err = testEngine.Find(&noses1) assert.NoError(t, err) assert.Equal(t, 2, len(noses1)) - // FIXME: - //assert.Equal(t, face1.Id, noses1[0].Face.Id) - //assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) assert.Equal(t, "", noses1[0].Face.Name) assert.Equal(t, "", noses1[1].Face.Name) @@ -161,9 +160,8 @@ func TestBelongsTo_Find(t *testing.T) { err = testEngine.Cascade().Find(&noses2) assert.NoError(t, err) assert.Equal(t, 2, len(noses2)) - // FIXME: - //assert.Equal(t, face1.Id, noses2[0].Face.Id) - //assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) assert.Equal(t, "face1", noses2[0].Face.Name) assert.Equal(t, "face2", noses2[1].Face.Name) } From 590bb1015bc1b3f3e51d9a0ea6d861705b12cd4b Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 4 Apr 2017 12:39:48 +0800 Subject: [PATCH 05/12] add FindPtr for belongs_to --- blongs_to_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/blongs_to_test.go b/blongs_to_test.go index b5f3983a..9ab2faa0 100644 --- a/blongs_to_test.go +++ b/blongs_to_test.go @@ -165,3 +165,54 @@ func TestBelongsTo_Find(t *testing.T) { assert.Equal(t, "face1", noses2[0].Face.Name) assert.Equal(t, "face2", noses2[1].Face.Name) } + +func TestBelongsTo_FindPtr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face *Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face1 = Face{ + Name: "face1", + } + var face2 = Face{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose{ + {Face: &face1}, + {Face: &face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + assert.Equal(t, face1.Id, noses1[0].Face.Id) + assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) + + var noses2 []Nose + err = testEngine.Cascade().Find(&noses2) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses2)) + assert.Equal(t, face1.Id, noses2[0].Face.Id) + assert.Equal(t, face2.Id, noses2[1].Face.Id) + assert.Equal(t, "face1", noses2[0].Face.Name) + assert.Equal(t, "face2", noses2[1].Face.Name) +} From 0abbd9fb91d77866682e785f36deb5a02dd71e29 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 4 Apr 2017 17:05:49 +0800 Subject: [PATCH 06/12] improved get cascade bean --- blongs_to_test.go | 4 ++ engine.go | 3 ++ session.go | 122 ++++++++++++++++++++++----------------------- session_convert.go | 61 +++++------------------ session_find.go | 5 +- statement.go | 4 +- 6 files changed, 82 insertions(+), 117 deletions(-) diff --git a/blongs_to_test.go b/blongs_to_test.go index 9ab2faa0..8a0cbfcb 100644 --- a/blongs_to_test.go +++ b/blongs_to_test.go @@ -1,3 +1,7 @@ +// Copyright 2017 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( diff --git a/engine.go b/engine.go index c3e0d8d0..1c0efc51 100644 --- a/engine.go +++ b/engine.go @@ -931,6 +931,7 @@ func (engine *Engine) mapType(parsingTables map[reflect.Type]*core.Table, v refl if ormTagStr != "" { col = &core.Column{ FieldName: t.Field(i).Name, + FieldType: t.Field(i).Type, Nullable: true, IsPrimaryKey: false, IsAutoIncrement: false, @@ -1044,6 +1045,7 @@ func (engine *Engine) mapType(parsingTables map[reflect.Type]*core.Table, v refl col = core.NewColumn( engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, + sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, @@ -1054,6 +1056,7 @@ func (engine *Engine) mapType(parsingTables map[reflect.Type]*core.Table, v refl idFieldColName = col.Name } + col.FieldType = t.Field(i).Type ctx.col = col } if col.IsAutoIncrement { diff --git a/session.go b/session.go index d7677163..830011bb 100644 --- a/session.go +++ b/session.go @@ -149,7 +149,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.cascadeMode = cascadeManuallyLoad + session.statement.cascadeMode = cascadeManually return session } @@ -204,12 +204,12 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { - var mode = cascadeAutoLoad + var mode = cascadeAuto if len(trueOrFalse) >= 1 { if trueOrFalse[0] { - mode = cascadeAutoLoad + mode = cascadeAuto } else { - mode = cascadeManuallyLoad + mode = cascadeManually } } @@ -447,8 +447,8 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b continue } - rawValueType := reflect.TypeOf(rawValue.Interface()) vv := reflect.ValueOf(rawValue.Interface()) + rawValueType := vv.Type() col := table.GetColumnIdx(key, idx) if col.IsPrimaryKey { pk = append(pk, rawValue.Interface()) @@ -639,35 +639,25 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } else if (col.AssociateType == core.AssociateNone && session.statement.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAutoLoad) { - table := col.AssociateTable - - hasAssigned = true - if len(table.PrimaryKeys) != 1 { - return nil, errors.New("unsupported non or composited primary key cascade") - } - var pk = make(core.PK, len(table.PrimaryKeys)) + session.statement.cascadeMode == cascadeAuto) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error - pk[0], err = asKind(vv, rawValueType) + rawValueType := col.AssociateTable.PKColumns()[0].FieldType + if rawValueType.Kind() == reflect.Ptr { + pk[0] = reflect.New(rawValueType.Elem()).Interface() + } else { + pk[0] = reflect.New(rawValueType).Interface() + } + err = convertAssign(pk[0], vv.Interface()) if err != nil { return nil, err } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return nil, err - } - if has { - fieldValue.Set(structInter.Elem()) - } else { - return nil, errors.New("cascade obj is not exist") - } + pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() + if err = session.getByPK(pk, fieldValue); err != nil { + return nil, err } + hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true err := convertAssign(fieldValue.FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), @@ -681,47 +671,25 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if (col.AssociateType == core.AssociateNone && session.statement.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAutoLoad) { - table := col.AssociateTable - - hasAssigned = true - if len(table.PrimaryKeys) != 1 { - panic("unsupported non or composited primary key cascade") - } - var pk = make(core.PK, len(table.PrimaryKeys)) + session.statement.cascadeMode == cascadeAuto) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error - pk[0], err = asKind(vv, rawValueType) + rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) + if rawValueType.Kind() == reflect.Ptr { + pk[0] = reflect.New(rawValueType.Elem()).Interface() + } else { + pk[0] = reflect.New(rawValueType).Interface() + } + err = convertAssign(pk[0], vv.Interface()) if err != nil { return nil, err } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - var structInter reflect.Value - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - structInter = reflect.New(fieldValue.Type().Elem()) - } else { - structInter = *fieldValue - } - } else { - structInter = fieldValue.Addr() - } - - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return nil, err - } - if has { - if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { - fieldValue.Set(structInter) - } - } else { - return nil, errors.New("cascade obj is not exist") - } + pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() + if err = session.getByPK(pk, fieldValue); err != nil { + return nil, err } + hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true if fieldValue.IsNil() { @@ -873,6 +841,34 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b return pk, nil } +func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { + if !isPKZero(pk) { + var structInter reflect.Value + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + structInter = reflect.New(fieldValue.Type().Elem()) + } else { + structInter = *fieldValue + } + } else { + structInter = fieldValue.Addr() + } + + has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + if err != nil { + return err + } + if has { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(structInter) + } + } else { + return errors.New("cascade obj is not exist") + } + } + return nil +} + // saveLastSQL stores executed query information func (session *Session) saveLastSQL(sql string, args ...interface{}) { session.lastSQL = sql diff --git a/session_convert.go b/session_convert.go index 1e6355c3..e69788b8 100644 --- a/session_convert.go +++ b/session_convert.go @@ -8,7 +8,6 @@ import ( "database/sql" "database/sql/driver" "encoding/json" - "errors" "fmt" "reflect" "strconv" @@ -206,37 +205,18 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else if (col.AssociateType == core.AssociateNone && session.statement.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAutoLoad) { - table := col.AssociateTable - - // TODO: current only support 1 primary key - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(core.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + session.statement.cascadeMode == cascadeAuto) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) + // only 1 PK checked on tag parsing + rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) var err error pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } + if err = session.getByPK(pk, fieldValue); err != nil { + return err } } } @@ -489,35 +469,18 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, if (col.AssociateType == core.AssociateNone && session.statement.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAutoLoad) { - table := col.AssociateTable - if len(table.PrimaryKeys) > 1 { - return errors.New("unsupported composited primary key cascade") - } - - var pk = make(core.PK, len(table.PrimaryKeys)) + session.statement.cascadeMode == cascadeAuto) { + var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + // only 1 PK checked on tag parsing + rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err } - if !isPKZero(pk) { - structInter := reflect.New(fieldType.Elem()) - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist") - } + if err = session.getByPK(pk, fieldValue); err != nil { + return err } } } diff --git a/session_find.go b/session_find.go index f95dcfef..ff39e87b 100644 --- a/session_find.go +++ b/session_find.go @@ -172,9 +172,8 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va var newElemFunc func(fields []string) reflect.Value elemType := containerValue.Type().Elem() - var isPointer bool - if elemType.Kind() == reflect.Ptr { - isPointer = true + var isPointer = elemType.Kind() == reflect.Ptr + if isPointer { elemType = elemType.Elem() } if elemType.Kind() == reflect.Ptr { diff --git a/statement.go b/statement.go index 4f4631ea..de8a071c 100644 --- a/statement.go +++ b/statement.go @@ -37,8 +37,8 @@ type cascadeMode int const ( cascadeCompitable cascadeMode = iota - cascadeAutoLoad - cascadeManuallyLoad + cascadeAuto + cascadeManually ) // Statement save all the sql info for executing SQL From 567b13889bdc855adbcb103488b52aa2b078afc1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 29 May 2017 08:30:20 +0800 Subject: [PATCH 07/12] rename EagerLoad -> EagerGet --- engine.go | 6 +- helpers.go | 29 +++---- helpers_test.go | 21 ----- interface.go | 2 +- session.go | 14 +++- session_associate.go | 77 ++++++++++++++++++- ...gs_to_test.go => session_associate_test.go | 4 +- tag.go | 20 +++++ tag_test.go | 19 +++++ 9 files changed, 141 insertions(+), 51 deletions(-) rename blongs_to_test.go => session_associate_test.go (98%) diff --git a/engine.go b/engine.go index 1c0efc51..18422ce5 100644 --- a/engine.go +++ b/engine.go @@ -1501,11 +1501,11 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } -// EagerLoad loads bean's belongs to tag field immedicatlly -func (engine *Engine) EagerLoad(bean interface{}, cols ...string) error { +// EagerGet loads bean's belongs to tag field immedicatlly +func (engine *Engine) EagerGet(bean interface{}, cols ...string) error { session := engine.NewSession() defer session.Close() - return session.EagerLoad(bean, cols...) + return session.EagerGet(bean, cols...) } // Find retrieve records from table, condiBeans's non-empty fields diff --git a/helpers.go b/helpers.go index b505a995..85240b24 100644 --- a/helpers.go +++ b/helpers.go @@ -104,26 +104,6 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) { return v.Interface(), nil } -func splitTag(tag string) (tags []string) { - tag = strings.TrimSpace(tag) - var hasQuote = false - var lastIdx = 0 - for i, t := range tag { - if t == '\'' { - hasQuote = !hasQuote - } else if t == ' ' { - if lastIdx < i && !hasQuote { - tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) - lastIdx = i + 1 - } - } - } - if lastIdx < len(tag) { - tags = append(tags, strings.TrimSpace(tag[lastIdx:])) - } - return -} - type zeroable interface { IsZero() bool } @@ -479,3 +459,12 @@ func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) return false, false } + +func isStringInSlice(s string, slice []string) bool { + for _, e := range slice { + if s == e { + return true + } + } + return false +} diff --git a/helpers_test.go b/helpers_test.go index d57c54ae..31aec506 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -3,24 +3,3 @@ // license that can be found in the LICENSE file. package xorm - -import "testing" - -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} diff --git a/interface.go b/interface.go index 4adee92e..e15126a7 100644 --- a/interface.go +++ b/interface.go @@ -28,7 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error - EagerLoad(interface{}, ...string) error + EagerGet(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 830011bb..82741b06 100644 --- a/session.go +++ b/session.go @@ -654,8 +654,18 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - if err = session.getByPK(pk, fieldValue); err != nil { - return nil, err + if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type())) + } + if err = session.getByPK(pk, fieldValue); err != nil { + return nil, err + } + } else { + v := fieldValue.Addr() + if err = session.getByPK(pk, &v); err != nil { + return nil, err + } } hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { diff --git a/session_associate.go b/session_associate.go index 8e2c6231..e11be89f 100644 --- a/session_associate.go +++ b/session_associate.go @@ -11,8 +11,77 @@ import ( "github.com/go-xorm/core" ) -// EagerLoad load bean's belongs to tag field immedicatlly -func (session *Session) EagerLoad(bean interface{}, cols ...string) error { +// EagerFind load 's belongs to tag field immedicatlly +func (session *Session) EagerFind(slices interface{}, cols ...string) error { + /*v := reflect.ValueOf(slices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + return errors.New("only slice is supported") + } + + if v.Len() <= 0 { + return nil + } + + vv := v.Index(0) + if vv.Kind() == reflect.Ptr { + vv = vv.Elem() + } + tb, err := session.Engine.autoMapType(vv) + if err != nil { + return err + } + + var pks = make(map[string][]core.PK) + for i := 0; i < v.Len(); i++ { + ev := v.Index(i) + + for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + + if col.AssociateTable != nil { + if col.AssociateType == core.AssociateBelongsTo { + colV, err := col.ValueOfV(&ev) + if err != nil { + return err + } + + pk, err := session.Engine.idOfV(*colV) + if err != nil { + return err + } + var colPtr reflect.Value + if colV.Kind() == reflect.Ptr { + colPtr = *colV + } else { + colPtr = colV.Addr() + } + + if !isZero(pk[0]) { + pks[col.Name] = append(pks[col.Name], pk) + } + } + } + } + } + + for colName, pk := range pks { + slice := reflect.MakeSlice(tp, 0, len(pk)) + err = session.In("", pk).Find(slice.Addr().Interafce()) + if err != nil { + return err + } + + }*/ + return nil +} + +// EagerGet load bean's belongs to tag field immedicatlly +func (session *Session) EagerGet(bean interface{}, cols ...string) error { if session.isAutoClose { defer session.Close() } @@ -24,6 +93,10 @@ func (session *Session) EagerLoad(bean interface{}, cols ...string) error { } for _, col := range tb.Columns() { + if len(cols) > 0 && !isStringInSlice(col.Name, cols) { + continue + } + if col.AssociateTable != nil { if col.AssociateType == core.AssociateBelongsTo { colV, err := col.ValueOfV(&v) diff --git a/blongs_to_test.go b/session_associate_test.go similarity index 98% rename from blongs_to_test.go rename to session_associate_test.go index 8a0cbfcb..021277d9 100644 --- a/blongs_to_test.go +++ b/session_associate_test.go @@ -50,7 +50,7 @@ func TestBelongsTo_Get(t *testing.T) { assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) - err = testEngine.EagerLoad(&cfgNose) + err = testEngine.EagerGet(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) @@ -104,7 +104,7 @@ func TestBelongsTo_GetPtr(t *testing.T) { assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) - err = testEngine.EagerLoad(&cfgNose) + err = testEngine.EagerGet(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) diff --git a/tag.go b/tag.go index 835c65f3..8c7f5b95 100644 --- a/tag.go +++ b/tag.go @@ -35,6 +35,26 @@ type tagContext struct { ignoreNext bool } +func splitTag(tag string) (tags []string) { + tag = strings.TrimSpace(tag) + var hasQuote = false + var lastIdx = 0 + for i, t := range tag { + if t == '\'' { + hasQuote = !hasQuote + } else if t == ' ' { + if lastIdx < i && !hasQuote { + tags = append(tags, strings.TrimSpace(tag[lastIdx:i])) + lastIdx = i + 1 + } + } + } + if lastIdx < len(tag) { + tags = append(tags, strings.TrimSpace(tag[lastIdx:])) + } + return +} + // tagHandler describes tag handler for XORM type tagHandler func(ctx *tagContext) error diff --git a/tag_test.go b/tag_test.go index c9b76048..f97f6a10 100644 --- a/tag_test.go +++ b/tag_test.go @@ -393,3 +393,22 @@ func TestTagTime(t *testing.T) { assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) } + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !sliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +} From 3538ce1752c88a8715f1a3b7fd5a238c93aa0331 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 3 Jul 2017 17:10:32 +0800 Subject: [PATCH 08/12] improve belongs_to --- session.go | 44 +++++++++++++++++++++++++++++++-------- session_associate_test.go | 2 ++ 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/session.go b/session.go index 82741b06..3044f6f3 100644 --- a/session.go +++ b/session.go @@ -17,6 +17,12 @@ import ( "github.com/go-xorm/core" ) +type loadClosure struct { + Func func(core.PK, *reflect.Value) error + pk core.PK + fieldValue *reflect.Value +} + // Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { @@ -658,14 +664,24 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b if fieldValue.IsNil() { fieldValue.Set(reflect.New(fieldValue.Type())) } - if err = session.getByPK(pk, fieldValue); err != nil { - return nil, err - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) } else { v := fieldValue.Addr() - if err = session.getByPK(pk, &v); err != nil { - return nil, err - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: &v, + }) } hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { @@ -696,13 +712,22 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - if err = session.getByPK(pk, fieldValue); err != nil { - return nil, err - } + fmt.Println("=====", fieldValue, fieldValue.IsNil()) + fmt.Printf("%#v", fieldValue) + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true if fieldValue.IsNil() { + // FIXME: find id column structInter := reflect.New(fieldValue.Type().Elem()) fieldValue.Set(structInter) } @@ -871,6 +896,7 @@ func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { if has { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(structInter) + fmt.Println("333", fieldValue.IsNil()) } } else { return errors.New("cascade obj is not exist") diff --git a/session_associate_test.go b/session_associate_test.go index 021277d9..58f83cd1 100644 --- a/session_associate_test.go +++ b/session_associate_test.go @@ -215,6 +215,8 @@ func TestBelongsTo_FindPtr(t *testing.T) { err = testEngine.Cascade().Find(&noses2) assert.NoError(t, err) assert.Equal(t, 2, len(noses2)) + assert.NotNil(t, noses2[0].Face) + assert.NotNil(t, noses2[1].Face) assert.Equal(t, face1.Id, noses2[0].Face.Id) assert.Equal(t, face2.Id, noses2[1].Face.Id) assert.Equal(t, "face1", noses2[0].Face.Name) From 23c6999de8fe5f16a5789917554ecd5c87395290 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 11 Sep 2017 15:43:39 +0800 Subject: [PATCH 09/12] merge EagerGet and EagerFind to Load --- engine.go | 6 ++-- interface.go | 2 +- session.go | 73 ++++++++++++++++++--------------------- session_associate.go | 27 +++++++++++---- session_associate_test.go | 4 +-- session_convert.go | 14 ++++---- statement.go | 9 +++-- tag_extends_test.go | 36 +++++-------------- 8 files changed, 80 insertions(+), 91 deletions(-) diff --git a/engine.go b/engine.go index 18422ce5..732f876e 100644 --- a/engine.go +++ b/engine.go @@ -1501,11 +1501,11 @@ func (engine *Engine) Exist(bean ...interface{}) (bool, error) { return session.Exist(bean...) } -// EagerGet loads bean's belongs to tag field immedicatlly -func (engine *Engine) EagerGet(bean interface{}, cols ...string) error { +// Load loads bean's belongs to tag field immedicatlly +func (engine *Engine) Load(bean interface{}, cols ...string) error { session := engine.NewSession() defer session.Close() - return session.EagerGet(bean, cols...) + return session.Load(bean, cols...) } // Find retrieve records from table, condiBeans's non-empty fields diff --git a/interface.go b/interface.go index e15126a7..ba548b1d 100644 --- a/interface.go +++ b/interface.go @@ -28,7 +28,7 @@ type Interface interface { Delete(interface{}) (int64, error) Distinct(columns ...string) *Session DropIndexes(bean interface{}) error - EagerGet(interface{}, ...string) error + Load(interface{}, ...string) error Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error diff --git a/session.go b/session.go index 3044f6f3..ff312547 100644 --- a/session.go +++ b/session.go @@ -21,6 +21,7 @@ type loadClosure struct { Func func(core.PK, *reflect.Value) error pk core.PK fieldValue *reflect.Value + loaded bool } // Session keep a pointer to sql.DB and provides all execution of all @@ -57,6 +58,9 @@ type Session struct { lastSQL string lastSQLArgs []interface{} + cascadeMode cascadeMode + cascadeLevel int // load level + err error } @@ -88,6 +92,9 @@ func (session *Session) Init() { session.lastSQL = "" session.lastSQLArgs = []interface{}{} + + session.cascadeMode = cascadeCompitable + session.cascadeLevel = 2 } // Close release the connection from pool @@ -155,7 +162,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.cascadeMode = cascadeManually + session.cascadeMode = cascadeLazy return session } @@ -210,16 +217,16 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { - var mode = cascadeAuto + var mode = cascadeEager if len(trueOrFalse) >= 1 { if trueOrFalse[0] { - mode = cascadeAuto + mode = cascadeEager } else { - mode = cascadeManually + mode = cascadeLazy } } - session.statement.cascadeMode = mode + session.cascadeMode = mode return session } @@ -642,10 +649,10 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - } else if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + } else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error rawValueType := col.AssociateTable.PKColumns()[0].FieldType @@ -660,29 +667,15 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - if fieldValue.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - fieldValue.Set(reflect.New(fieldValue.Type())) - } - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(session *Session, bean interface{}) error { - fieldValue := bean.(*reflect.Value) - return session.getByPK(pk, fieldValue) - }, - session: session, - bean: fieldValue, - }) - } else { - v := fieldValue.Addr() - session.afterProcessors = append(session.afterProcessors, executedProcessor{ - fun: func(session *Session, bean interface{}) error { - fieldValue := bean.(*reflect.Value) - return session.getByPK(pk, fieldValue) - }, - session: session, - bean: &v, - }) - } + session.afterProcessors = append(session.afterProcessors, executedProcessor{ + fun: func(session *Session, bean interface{}) error { + fieldValue := bean.(*reflect.Value) + return session.getByPK(pk, fieldValue) + }, + session: session, + bean: fieldValue, + }) + session.cascadeLevel-- hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true @@ -694,10 +687,10 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } case reflect.Ptr: if fieldType != core.PtrTimeType && fieldType.Elem().Kind() == reflect.Struct { - if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) @@ -712,8 +705,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } pk[0] = reflect.ValueOf(pk[0]).Elem().Interface() - fmt.Println("=====", fieldValue, fieldValue.IsNil()) - fmt.Printf("%#v", fieldValue) session.afterProcessors = append(session.afterProcessors, executedProcessor{ fun: func(session *Session, bean interface{}) error { fieldValue := bean.(*reflect.Value) @@ -723,6 +714,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b bean: fieldValue, }) + session.cascadeLevel-- hasAssigned = true } else if col.AssociateType == core.AssociateBelongsTo { hasAssigned = true @@ -732,7 +724,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue.Set(structInter) } - //fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Set(vv) err := convertAssign(fieldValue.Elem().FieldByName(table.PKColumns()[0].FieldName).Addr().Interface(), vv.Interface()) if err != nil { @@ -741,7 +732,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } else { // !nashtsai! TODO merge duplicated codes above - //typeStr := fieldType.String() switch fieldType { // following types case matching ptr's native type, therefore assign ptr directly case core.PtrStringType: @@ -889,14 +879,17 @@ func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { structInter = fieldValue.Addr() } - has, err := session.ID(pk).NoCascade().get(structInter.Interface()) + has, err := session.ID(pk).NoAutoCondition().get(structInter.Interface()) if err != nil { return err } if has { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(structInter) - fmt.Println("333", fieldValue.IsNil()) + } else if fieldValue.Kind() == reflect.Struct { + fieldValue.Set(structInter.Elem()) + } else { + return errors.New("set value failed") } } else { return errors.New("cascade obj is not exist") diff --git a/session_associate.go b/session_associate.go index e11be89f..aa922db2 100644 --- a/session_associate.go +++ b/session_associate.go @@ -11,8 +11,22 @@ import ( "github.com/go-xorm/core" ) -// EagerFind load 's belongs to tag field immedicatlly -func (session *Session) EagerFind(slices interface{}, cols ...string) error { +// Load loads associated fields from database +func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { + v := reflect.ValueOf(beanOrSlices) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() == reflect.Slice { + return session.loadFind(beanOrSlices, cols...) + } else if v.Kind() == reflect.Struct { + return session.loadGet(beanOrSlices, cols...) + } + return errors.New("unsupported load type, must struct or slice") +} + +// loadFind load 's belongs to tag field immedicatlly +func (session *Session) loadFind(slices interface{}, cols ...string) error { /*v := reflect.ValueOf(slices) if v.Kind() == reflect.Ptr { v = v.Elem() @@ -80,8 +94,8 @@ func (session *Session) EagerFind(slices interface{}, cols ...string) error { return nil } -// EagerGet load bean's belongs to tag field immedicatlly -func (session *Session) EagerGet(bean interface{}, cols ...string) error { +// loadGet load bean's belongs to tag field immedicatlly +func (session *Session) loadGet(bean interface{}, cols ...string) error { if session.isAutoClose { defer session.Close() } @@ -115,14 +129,15 @@ func (session *Session) EagerGet(bean interface{}, cols ...string) error { colPtr = colV.Addr() } - if !isZero(pk[0]) { - has, err := session.ID(pk).get(colPtr.Interface()) + if !isZero(pk[0]) && session.cascadeLevel > 0 { + has, err := session.ID(pk).NoAutoCondition().get(colPtr.Interface()) if err != nil { return err } if !has { return errors.New("load bean does not exist") } + session.cascadeLevel-- } } } diff --git a/session_associate_test.go b/session_associate_test.go index 58f83cd1..62e46cbc 100644 --- a/session_associate_test.go +++ b/session_associate_test.go @@ -50,7 +50,7 @@ func TestBelongsTo_Get(t *testing.T) { assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "", cfgNose.Face.Name) - err = testEngine.EagerGet(&cfgNose) + err = testEngine.Load(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) @@ -104,7 +104,7 @@ func TestBelongsTo_GetPtr(t *testing.T) { assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) - err = testEngine.EagerGet(&cfgNose) + err = testEngine.Load(&cfgNose) assert.NoError(t, err) assert.Equal(t, nose.Id, cfgNose.Id) assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) diff --git a/session_convert.go b/session_convert.go index e69788b8..de7cfbec 100644 --- a/session_convert.go +++ b/session_convert.go @@ -202,10 +202,10 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + } else if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) // only 1 PK checked on tag parsing rawValueType := col.AssociateTable.ColumnType(col.AssociateTable.PKColumns()[0].FieldName) @@ -215,6 +215,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return err } + session.cascadeLevel-- if err = session.getByPK(pk, fieldValue); err != nil { return err } @@ -466,10 +467,10 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(&x)) default: - if (col.AssociateType == core.AssociateNone && - session.statement.cascadeMode == cascadeCompitable) || + if session.cascadeLevel > 0 && ((col.AssociateType == core.AssociateNone && + session.cascadeMode == cascadeCompitable) || (col.AssociateType == core.AssociateBelongsTo && - session.statement.cascadeMode == cascadeAuto) { + session.cascadeMode == cascadeEager)) { var pk = make(core.PK, len(col.AssociateTable.PrimaryKeys)) var err error // only 1 PK checked on tag parsing @@ -479,6 +480,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return err } + session.cascadeLevel-- if err = session.getByPK(pk, fieldValue); err != nil { return err } diff --git a/statement.go b/statement.go index de8a071c..b2d1e90b 100644 --- a/statement.go +++ b/statement.go @@ -36,9 +36,10 @@ type exprParam struct { type cascadeMode int const ( - cascadeCompitable cascadeMode = iota - cascadeAuto - cascadeManually + cascadeCompitable cascadeMode = iota // load field beans with another SQL with no + cascadeEager // load field beans with another SQL + cascadeJoin // load field beans with join + cascadeLazy // don't load anything ) // Statement save all the sql info for executing SQL @@ -62,7 +63,6 @@ type Statement struct { tableName string RawSQL string RawParams []interface{} - cascadeMode cascadeMode UseAutoJoin bool StoreEngine string Charset string @@ -90,7 +90,6 @@ func (statement *Statement) Init() { statement.Start = 0 statement.LimitN = 0 statement.OrderStr = "" - statement.cascadeMode = cascadeCompitable statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" diff --git a/tag_extends_test.go b/tag_extends_test.go index b70eefe3..6c72a6dc 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -169,22 +169,12 @@ func TestExtends(t *testing.T) { tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - if err != nil { - t.Error(err) - panic(err) - } + assert.Error(t, err) users := make([]tempUser3, 0) err = testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - err = errors.New("error get data not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, len(users)) assertSync(t, new(Userinfo), new(Userdetail)) @@ -209,21 +199,11 @@ func TestExtends(t *testing.T) { sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s", qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid)) b, err := testEngine.SQL(sql).NoCascade().Get(&info) - if err != nil { - t.Error(err) - panic(err) - } - if !b { - err = errors.New("should has lest one record") - t.Error(err) - panic(err) - } - fmt.Println(info) - if info.Userinfo.Uid == 0 || info.Userdetail.Id == 0 { - err = errors.New("all of the id should has value") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, b) + assert.False(t, info.Userinfo.Uid == 0 || info.Userdetail.Id == 0) + + panic("") fmt.Println("----join--info2") var info2 UserAndDetail From 1d5bc623f33b3d81e0c398fbaec3ca3ac0763faa Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 12 Sep 2017 21:43:43 +0800 Subject: [PATCH 10/12] improve some functions --- session.go | 2 ++ session_associate_test.go | 2 ++ session_find.go | 2 ++ 3 files changed, 6 insertions(+) diff --git a/session.go b/session.go index ff312547..3f9656ee 100644 --- a/session.go +++ b/session.go @@ -886,8 +886,10 @@ func (session *Session) getByPK(pk core.PK, fieldValue *reflect.Value) error { if has { if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { fieldValue.Set(structInter) + fmt.Println("getByPK value ptr:", fieldValue.Interface()) } else if fieldValue.Kind() == reflect.Struct { fieldValue.Set(structInter.Elem()) + fmt.Println("getByPK value:", fieldValue.Interface()) } else { return errors.New("set value failed") } diff --git a/session_associate_test.go b/session_associate_test.go index 62e46cbc..ae2b0ee0 100644 --- a/session_associate_test.go +++ b/session_associate_test.go @@ -5,6 +5,7 @@ package xorm import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -164,6 +165,7 @@ func TestBelongsTo_Find(t *testing.T) { err = testEngine.Cascade().Find(&noses2) assert.NoError(t, err) assert.Equal(t, 2, len(noses2)) + fmt.Println("noses:", noses2) assert.Equal(t, face1.Id, noses2[0].Face.Id) assert.Equal(t, face2.Id, noses2[1].Face.Id) assert.Equal(t, "face1", noses2[0].Face.Name) diff --git a/session_find.go b/session_find.go index ff39e87b..f1fc3a84 100644 --- a/session_find.go +++ b/session_find.go @@ -203,7 +203,9 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va if isPointer { containerValue.Set(reflect.Append(containerValue, newValue.Elem().Addr())) } else { + fmt.Println("---", newValue.Elem()) containerValue.Set(reflect.Append(containerValue, newValue.Elem())) + fmt.Println("===", containerValue.Interface()) } return nil } From 2b06f05d40c8a2523e8a5eed45e27bb9ae9a323a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Oct 2017 23:54:10 +0800 Subject: [PATCH 11/12] fix extends tag --- tag.go | 7 +++--- tag_extends_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 6 deletions(-) diff --git a/tag.go b/tag.go index 8c7f5b95..40405cc7 100644 --- a/tag.go +++ b/tag.go @@ -286,11 +286,12 @@ func ExtendsTagHandler(ctx *tagContext) error { if err != nil { return err } - for _, col := range parentTable.Columns() { + for _, oriCol := range parentTable.Columns() { + col := *oriCol col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) - ctx.table.AddColumn(col) + ctx.table.AddColumn(&col) for indexName, indexType := range col.Indexes { - addIndex(indexName, ctx.table, col, indexType) + addIndex(indexName, ctx.table, &col, indexType) } } default: diff --git a/tag_extends_test.go b/tag_extends_test.go index 6c72a6dc..f0af21d7 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -169,7 +169,7 @@ func TestExtends(t *testing.T) { tu6 := &tempUser3{&tempUser{0, "extends update"}, ""} _, err = testEngine.ID(tu5.Temp.Id).Update(tu6) - assert.Error(t, err) + assert.NoError(t, err) users := make([]tempUser3, 0) err = testEngine.Find(&users) @@ -203,8 +203,6 @@ func TestExtends(t *testing.T) { assert.True(t, b) assert.False(t, info.Userinfo.Uid == 0 || info.Userdetail.Id == 0) - panic("") - fmt.Println("----join--info2") var info2 UserAndDetail b, err = testEngine.Table(&Userinfo{}). @@ -516,3 +514,55 @@ func TestExtends4(t *testing.T) { panic(err) } } + +func TestExtendsTag(t *testing.T) { + assert.NoError(t, prepareEngine()) + + table := testEngine.TableInfo(new(Userdetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 3, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[1]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[2]) + + table = testEngine.TableInfo(new(Userinfo)) + assert.NotNil(t, table) + assert.EqualValues(t, 8, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + + table = testEngine.TableInfo(new(UserAndDetail)) + assert.NotNil(t, table) + assert.EqualValues(t, 11, len(table.ColumnsSeq())) + assert.EqualValues(t, "id", table.ColumnsSeq()[0]) + assert.EqualValues(t, "username", table.ColumnsSeq()[1]) + assert.EqualValues(t, "departname", table.ColumnsSeq()[2]) + assert.EqualValues(t, "created", table.ColumnsSeq()[3]) + assert.EqualValues(t, "detail_id", table.ColumnsSeq()[4]) + assert.EqualValues(t, "height", table.ColumnsSeq()[5]) + assert.EqualValues(t, "avatar", table.ColumnsSeq()[6]) + assert.EqualValues(t, "is_man", table.ColumnsSeq()[7]) + assert.EqualValues(t, "id", table.ColumnsSeq()[8]) + assert.EqualValues(t, "intro", table.ColumnsSeq()[9]) + assert.EqualValues(t, "profile", table.ColumnsSeq()[10]) + + cols := table.Columns() + assert.EqualValues(t, 11, len(cols)) + assert.EqualValues(t, "Userinfo.Uid", cols[0].FieldName) + assert.EqualValues(t, "Userinfo.Username", cols[1].FieldName) + assert.EqualValues(t, "Userinfo.Departname", cols[2].FieldName) + assert.EqualValues(t, "Userinfo.Created", cols[3].FieldName) + assert.EqualValues(t, "Userinfo.Detail", cols[4].FieldName) + assert.EqualValues(t, "Userinfo.Height", cols[5].FieldName) + assert.EqualValues(t, "Userinfo.Avatar", cols[6].FieldName) + assert.EqualValues(t, "Userinfo.IsMan", cols[7].FieldName) + assert.EqualValues(t, "Userdetail.Id", cols[8].FieldName) + assert.EqualValues(t, "Userdetail.Intro", cols[9].FieldName) + assert.EqualValues(t, "Userdetail.Profile", cols[10].FieldName) +} From b8373f09d7e775f238cc7ae097540ce8543791cd Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 31 Oct 2017 16:51:22 +0800 Subject: [PATCH 12/12] improve Load find --- session_associate.go | 23 ++++++----- session_associate_test.go | 82 +++++++++++++++++++++------------------ 2 files changed, 55 insertions(+), 50 deletions(-) diff --git a/session_associate.go b/session_associate.go index aa922db2..62e924a3 100644 --- a/session_associate.go +++ b/session_associate.go @@ -27,7 +27,7 @@ func (session *Session) Load(beanOrSlices interface{}, cols ...string) error { // loadFind load 's belongs to tag field immedicatlly func (session *Session) loadFind(slices interface{}, cols ...string) error { - /*v := reflect.ValueOf(slices) + v := reflect.ValueOf(slices) if v.Kind() == reflect.Ptr { v = v.Elem() } @@ -43,12 +43,12 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { if vv.Kind() == reflect.Ptr { vv = vv.Elem() } - tb, err := session.Engine.autoMapType(vv) + tb, err := session.engine.autoMapType(vv) if err != nil { return err } - var pks = make(map[string][]core.PK) + var pks = make(map[*core.Column][]interface{}) for i := 0; i < v.Len(); i++ { ev := v.Index(i) @@ -64,33 +64,32 @@ func (session *Session) loadFind(slices interface{}, cols ...string) error { return err } - pk, err := session.Engine.idOfV(*colV) + pk, err := session.engine.idOfV(*colV) if err != nil { return err } - var colPtr reflect.Value + /*var colPtr reflect.Value if colV.Kind() == reflect.Ptr { colPtr = *colV } else { colPtr = colV.Addr() - } + }*/ if !isZero(pk[0]) { - pks[col.Name] = append(pks[col.Name], pk) + pks[col] = append(pks[col], pk[0]) } } } } } - for colName, pk := range pks { - slice := reflect.MakeSlice(tp, 0, len(pk)) - err = session.In("", pk).Find(slice.Addr().Interafce()) + for col, pk := range pks { + slice := reflect.MakeSlice(col.FieldType, 0, len(pk)) + err = session.In(col.Name, pk...).find(slice.Addr().Interface()) if err != nil { return err } - - }*/ + } return nil } diff --git a/session_associate_test.go b/session_associate_test.go index ae2b0ee0..94624216 100644 --- a/session_associate_test.go +++ b/session_associate_test.go @@ -5,7 +5,6 @@ package xorm import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -14,36 +13,36 @@ import ( func TestBelongsTo_Get(t *testing.T) { assert.NoError(t, prepareEngine()) - type Face struct { + type Face1 struct { Id int64 Name string } - type Nose struct { + type Nose1 struct { Id int64 - Face Face `xorm:"belongs_to"` + Face Face1 `xorm:"belongs_to"` } - err := testEngine.Sync2(new(Nose), new(Face)) + err := testEngine.Sync2(new(Nose1), new(Face1)) assert.NoError(t, err) - var face = Face{ + var face = Face1{ Name: "face1", } _, err = testEngine.Insert(&face) assert.NoError(t, err) - var cfgFace Face + var cfgFace Face1 has, err := testEngine.Get(&cfgFace) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, face, cfgFace) - var nose = Nose{Face: face} + var nose = Nose1{Face: face} _, err = testEngine.Insert(&nose) assert.NoError(t, err) - var cfgNose Nose + var cfgNose Nose1 has, err = testEngine.Get(&cfgNose) assert.NoError(t, err) assert.Equal(t, true, has) @@ -57,7 +56,7 @@ func TestBelongsTo_Get(t *testing.T) { assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "face1", cfgNose.Face.Name) - var cfgNose2 Nose + var cfgNose2 Nose1 has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) @@ -69,36 +68,36 @@ func TestBelongsTo_Get(t *testing.T) { func TestBelongsTo_GetPtr(t *testing.T) { assert.NoError(t, prepareEngine()) - type Face struct { + type Face2 struct { Id int64 Name string } - type Nose struct { + type Nose2 struct { Id int64 - Face *Face `xorm:"belongs_to"` + Face *Face2 `xorm:"belongs_to"` } - err := testEngine.Sync2(new(Nose), new(Face)) + err := testEngine.Sync2(new(Nose2), new(Face2)) assert.NoError(t, err) - var face = Face{ + var face = Face2{ Name: "face1", } _, err = testEngine.Insert(&face) assert.NoError(t, err) - var cfgFace Face + var cfgFace Face2 has, err := testEngine.Get(&cfgFace) assert.NoError(t, err) assert.Equal(t, true, has) assert.Equal(t, face, cfgFace) - var nose = Nose{Face: &face} + var nose = Nose2{Face: &face} _, err = testEngine.Insert(&nose) assert.NoError(t, err) - var cfgNose Nose + var cfgNose Nose2 has, err = testEngine.Get(&cfgNose) assert.NoError(t, err) assert.Equal(t, true, has) @@ -111,7 +110,7 @@ func TestBelongsTo_GetPtr(t *testing.T) { assert.Equal(t, nose.Face.Id, cfgNose.Face.Id) assert.Equal(t, "face1", cfgNose.Face.Name) - var cfgNose2 Nose + var cfgNose2 Nose2 has, err = testEngine.Cascade().Get(&cfgNose2) assert.NoError(t, err) assert.Equal(t, true, has) @@ -123,36 +122,36 @@ func TestBelongsTo_GetPtr(t *testing.T) { func TestBelongsTo_Find(t *testing.T) { assert.NoError(t, prepareEngine()) - type Face struct { + type Face3 struct { Id int64 Name string } - type Nose struct { + type Nose3 struct { Id int64 - Face Face `xorm:"belongs_to"` + Face Face3 `xorm:"belongs_to"` } - err := testEngine.Sync2(new(Nose), new(Face)) + err := testEngine.Sync2(new(Nose3), new(Face3)) assert.NoError(t, err) - var face1 = Face{ + var face1 = Face3{ Name: "face1", } - var face2 = Face{ + var face2 = Face3{ Name: "face2", } _, err = testEngine.Insert(&face1, &face2) assert.NoError(t, err) - var noses = []Nose{ + var noses = []Nose3{ {Face: face1}, {Face: face2}, } _, err = testEngine.Insert(&noses) assert.NoError(t, err) - var noses1 []Nose + var noses1 []Nose3 err = testEngine.Find(&noses1) assert.NoError(t, err) assert.Equal(t, 2, len(noses1)) @@ -161,50 +160,54 @@ func TestBelongsTo_Find(t *testing.T) { assert.Equal(t, "", noses1[0].Face.Name) assert.Equal(t, "", noses1[1].Face.Name) - var noses2 []Nose + var noses2 []Nose3 err = testEngine.Cascade().Find(&noses2) assert.NoError(t, err) assert.Equal(t, 2, len(noses2)) - fmt.Println("noses:", noses2) assert.Equal(t, face1.Id, noses2[0].Face.Id) assert.Equal(t, face2.Id, noses2[1].Face.Id) assert.Equal(t, "face1", noses2[0].Face.Name) assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses1, "face") + assert.NoError(t, err) + assert.Equal(t, "face1", noses1[0].Face.Name) + assert.Equal(t, "face2", noses1[1].Face.Name) } func TestBelongsTo_FindPtr(t *testing.T) { assert.NoError(t, prepareEngine()) - type Face struct { + type Face4 struct { Id int64 Name string } - type Nose struct { + type Nose4 struct { Id int64 - Face *Face `xorm:"belongs_to"` + Face *Face4 `xorm:"belongs_to"` } - err := testEngine.Sync2(new(Nose), new(Face)) + err := testEngine.Sync2(new(Nose4), new(Face4)) assert.NoError(t, err) - var face1 = Face{ + var face1 = Face4{ Name: "face1", } - var face2 = Face{ + var face2 = Face4{ Name: "face2", } _, err = testEngine.Insert(&face1, &face2) assert.NoError(t, err) - var noses = []Nose{ + var noses = []Nose4{ {Face: &face1}, {Face: &face2}, } _, err = testEngine.Insert(&noses) assert.NoError(t, err) - var noses1 []Nose + var noses1 []Nose4 err = testEngine.Find(&noses1) assert.NoError(t, err) assert.Equal(t, 2, len(noses1)) @@ -213,7 +216,7 @@ func TestBelongsTo_FindPtr(t *testing.T) { assert.Equal(t, "", noses1[0].Face.Name) assert.Equal(t, "", noses1[1].Face.Name) - var noses2 []Nose + var noses2 []Nose4 err = testEngine.Cascade().Find(&noses2) assert.NoError(t, err) assert.Equal(t, 2, len(noses2)) @@ -223,4 +226,7 @@ func TestBelongsTo_FindPtr(t *testing.T) { assert.Equal(t, face2.Id, noses2[1].Face.Id) assert.Equal(t, "face1", noses2[0].Face.Name) assert.Equal(t, "face2", noses2[1].Face.Name) + + err = testEngine.Load(noses2, "face") + assert.NoError(t, err) }