From 557d5a4101df9e1f7eeffc3696a43ccf2d378514 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 14 Mar 2017 22:25:10 +0800 Subject: [PATCH] implement simple belongs_to tag --- blongs_to_test.go | 149 ++++++++++++++++++++++++++++++++++++++++++++ dialect_postgres.go | 2 +- engine.go | 101 ++++++++++++++++++++++++------ helpers.go | 8 +++ session.go | 26 +++++--- session_convert.go | 27 ++++---- session_schema.go | 2 +- statement.go | 12 +++- tag.go | 105 ++++++++++++++++++++++--------- xorm_test.go | 3 + 10 files changed, 359 insertions(+), 76 deletions(-) create mode 100644 blongs_to_test.go diff --git a/blongs_to_test.go b/blongs_to_test.go new file mode 100644 index 00000000..b6d8b1b5 --- /dev/null +++ b/blongs_to_test.go @@ -0,0 +1,149 @@ +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBelongsTo_Get(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face = Face{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgFace, face) + + var nose = Nose{Face: face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose.Id, nose.Id) + // FIXME: the id should be set back to the field + //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + assert.Equal(t, "", cfgNose.Face.Name) + + var cfgNose2 Nose + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose2.Id, nose.Id) + assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_GetPtr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face *Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face = Face{ + Name: "face1", + } + _, err = testEngine.Insert(&face) + assert.NoError(t, err) + + var cfgFace Face + has, err := testEngine.Get(&cfgFace) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgFace, face) + + var nose = Nose{Face: &face} + _, err = testEngine.Insert(&nose) + assert.NoError(t, err) + + var cfgNose Nose + has, err = testEngine.Get(&cfgNose) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose.Id, nose.Id) + // FIXME: the id should be set back to the field + //assert.Equal(t, cfgNose.Face.Id, nose.Face.Id) + + var cfgNose2 Nose + has, err = testEngine.Cascade().Get(&cfgNose2) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, cfgNose2.Id, nose.Id) + assert.Equal(t, cfgNose2.Face.Id, nose.Face.Id) + assert.Equal(t, "face1", cfgNose2.Face.Name) +} + +func TestBelongsTo_Find(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type Face struct { + Id int64 + Name string + } + + type Nose struct { + Id int64 + Face Face `xorm:"belongs_to"` + } + + err := testEngine.Sync2(new(Nose), new(Face)) + assert.NoError(t, err) + + var face1 = Face{ + Name: "face1", + } + var face2 = Face{ + Name: "face2", + } + _, err = testEngine.Insert(&face1, &face2) + assert.NoError(t, err) + + var noses = []Nose{ + {Face: face1}, + {Face: face2}, + } + _, err = testEngine.Insert(&noses) + assert.NoError(t, err) + + var noses1 []Nose + err = testEngine.Find(&noses1) + assert.NoError(t, err) + assert.Equal(t, 2, len(noses1)) + // FIXME: + //assert.Equal(t, face1.Id, noses1[0].Face.Id) + //assert.Equal(t, face2.Id, noses1[1].Face.Id) + assert.Equal(t, "", noses1[0].Face.Name) + assert.Equal(t, "", noses1[1].Face.Name) +} diff --git a/dialect_postgres.go b/dialect_postgres.go index 3f5c526f..0f9679c7 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -811,7 +811,7 @@ func (db *postgres) SqlType(c *core.Column) string { case core.NVarchar: res = core.Varchar case core.Uuid: - res = core.Uuid + return core.Uuid case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: return core.Bytea case core.Double: diff --git a/engine.go b/engine.go index 2b986966..5ca568e1 100644 --- a/engine.go +++ b/engine.go @@ -9,7 +9,6 @@ import ( "bytes" "database/sql" "encoding/gob" - "errors" "fmt" "io" "os" @@ -21,6 +20,7 @@ import ( "github.com/go-xorm/builder" "github.com/go-xorm/core" + "github.com/pkg/errors" ) // Engine is the major struct of xorm, it means a database manager. @@ -792,13 +792,18 @@ func (engine *Engine) UnMapType(t reflect.Type) { } func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) { - t := v.Type() engine.mutex.Lock() defer engine.mutex.Unlock() + return engine.autoMapTypeNoLock(v) +} + +func (engine *Engine) autoMapTypeNoLock(v reflect.Value) (*core.Table, error) { + t := v.Type() table, ok := engine.Tables[t] if !ok { var err error - table, err = engine.mapType(v) + var parsingTables = make(map[reflect.Type]*core.Table) + table, err = engine.mapType(parsingTables, v) if err != nil { return nil, err } @@ -872,9 +877,17 @@ var ( tpTableName = reflect.TypeOf((*TableName)(nil)).Elem() ) -func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { +func (engine *Engine) mapType(parsingTables map[reflect.Type]*core.Table, v reflect.Value) (*core.Table, error) { + if v.Kind() != reflect.Struct { + return nil, errors.New("need a struct to map") + } t := v.Type() + if table, ok := parsingTables[t]; ok { + return table, nil + } + table := engine.newTable() + parsingTables[t] = table if tb, ok := v.Interface().(TableName); ok { table.Name = tb.TableName() } else { @@ -901,24 +914,37 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { fieldValue := v.Field(i) fieldType := fieldValue.Type() - if ormTagStr != "" { - col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]int)} - tags := splitTag(ormTagStr) + var ctx = tagContext{ + engine: engine, + parsingTables: parsingTables, + table: table, + hasCacheTag: false, + hasNoCacheTag: false, + + fieldValue: fieldValue, + indexNames: make(map[string]int), + isIndex: false, + isUnique: false, + } + + if ormTagStr != "" { + col = &core.Column{ + FieldName: t.Field(i).Name, + Nullable: true, + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: core.TWOSIDES, + Indexes: make(map[string]int), + } + ctx.col = col + + tags := splitTag(ormTagStr) if len(tags) > 0 { if tags[0] == "-" { continue } - var ctx = tagContext{ - table: table, - col: col, - fieldValue: fieldValue, - indexNames: make(map[string]int), - engine: engine, - } - if strings.ToUpper(tags[0]) == "EXTENDS" { if err := ExtendsTagHandler(&ctx); err != nil { return nil, err @@ -1014,20 +1040,55 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { } else { sqlType = core.Type2SQLType(fieldType) } - col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), - t.Field(i).Name, sqlType, sqlType.DefaultLength, - sqlType.DefaultLength2, true) + + col = core.NewColumn( + engine.ColumnMapper.Obj2Table(t.Field(i).Name), + t.Field(i).Name, + sqlType, + sqlType.DefaultLength, + sqlType.DefaultLength2, + true, + ) if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) { idFieldColName = col.Name } + + ctx.col = col } if col.IsAutoIncrement { col.Nullable = false } - table.AddColumn(col) + var tp = fieldType + if isPtrStruct(fieldType) { + tp = fieldType.Elem() + } + if isStruct(tp) && col.AssociateTable == nil { + var isBelongsTo = !(tp.ConvertibleTo(core.TimeType) || col.SQLType.IsJson()) + if _, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Interface().(sql.Scanner); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + isBelongsTo = false + } + if _, ok := fieldValue.Interface().(core.Conversion); ok { + isBelongsTo = false + } + if isBelongsTo { + err := BelongsToTagHandler(&ctx) + if err != nil { + return nil, err + } + col.AssociateType = core.AssociateNone + } + } + + table.AddColumn(col) } // end for if idFieldColName != "" && len(table.PrimaryKeys) == 0 { diff --git a/helpers.go b/helpers.go index f39ed472..b505a995 100644 --- a/helpers.go +++ b/helpers.go @@ -16,6 +16,14 @@ import ( "github.com/go-xorm/core" ) +func isStruct(t reflect.Type) bool { + return t.Kind() == reflect.Struct || isPtrStruct(t) +} + +func isPtrStruct(t reflect.Type) bool { + return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct +} + // str2PK convert string value to primary key value according to tp func str2PKValue(s string, tp reflect.Type) (reflect.Value, error) { var err error diff --git a/session.go b/session.go index 5c6cb5f9..789ec35c 100644 --- a/session.go +++ b/session.go @@ -149,7 +149,7 @@ func (session *Session) Alias(alias string) *Session { // NoCascade indicate that no cascade load child object func (session *Session) NoCascade() *Session { - session.statement.UseCascade = false + session.statement.cascadeMode = cascadeManuallyLoad return session } @@ -204,9 +204,16 @@ func (session *Session) Charset(charset string) *Session { // Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { + var mode = cascadeAutoLoad if len(trueOrFalse) >= 1 { - session.statement.UseCascade = trueOrFalse[0] + if trueOrFalse[0] { + mode = cascadeAutoLoad + } else { + mode = cascadeManuallyLoad + } } + + session.statement.cascadeMode = mode return session } @@ -629,7 +636,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b session.engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - } else if col.SQLType.IsJson() { + /*} else if col.SQLType.IsJson() { if rawValueType.Kind() == reflect.String { hasAssigned = true x := reflect.New(fieldType) @@ -650,18 +657,19 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } fieldValue.Set(x.Elem()) } - } - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return nil, err - } + }*/ + } else if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable hasAssigned = true if len(table.PrimaryKeys) != 1 { return nil, errors.New("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) + var err error pk[0], err = asKind(vv, rawValueType) if err != nil { return nil, err diff --git a/session_convert.go b/session_convert.go index f2c949ba..1e6355c3 100644 --- a/session_convert.go +++ b/session_convert.go @@ -203,11 +203,11 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.statement.UseCascade { - table, err := session.engine.autoMapType(*fieldValue) - if err != nil { - return err - } + } else if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable // TODO: current only support 1 primary key if len(table.PrimaryKeys) > 1 { @@ -216,6 +216,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, var pk = make(core.PK, len(table.PrimaryKeys)) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + var err error pk[0], err = str2PK(string(data), rawValueType) if err != nil { return err @@ -485,18 +486,17 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, v = x fieldValue.Set(reflect.ValueOf(&x)) default: - if session.statement.UseCascade { - structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.autoMapType(structInter.Elem()) - if err != nil { - return err - } - + if (col.AssociateType == core.AssociateNone && + session.statement.cascadeMode == cascadeCompitable) || + (col.AssociateType == core.AssociateBelongsTo && + session.statement.cascadeMode == cascadeAutoLoad) { + table := col.AssociateTable if len(table.PrimaryKeys) > 1 { return errors.New("unsupported composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) + var err error rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) pk[0], err = str2PK(string(data), rawValueType) if err != nil { @@ -504,6 +504,7 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } if !isPKZero(pk) { + structInter := reflect.New(fieldType.Elem()) // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // property to be fetched lazily @@ -518,8 +519,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, return errors.New("cascade obj is not exist") } } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } default: diff --git a/session_schema.go b/session_schema.go index a2708b73..8960cdfe 100644 --- a/session_schema.go +++ b/session_schema.go @@ -259,7 +259,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - table, err := engine.mapType(v) + table, err := engine.autoMapType(v) if err != nil { return err } diff --git a/statement.go b/statement.go index 23346c71..4f4631ea 100644 --- a/statement.go +++ b/statement.go @@ -33,6 +33,14 @@ type exprParam struct { expr string } +type cascadeMode int + +const ( + cascadeCompitable cascadeMode = iota + cascadeAutoLoad + cascadeManuallyLoad +) + // Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -54,7 +62,7 @@ type Statement struct { tableName string RawSQL string RawParams []interface{} - UseCascade bool + cascadeMode cascadeMode UseAutoJoin bool StoreEngine string Charset string @@ -82,7 +90,7 @@ func (statement *Statement) Init() { statement.Start = 0 statement.LimitN = 0 statement.OrderStr = "" - statement.UseCascade = true + statement.cascadeMode = cascadeCompitable statement.JoinStr = "" statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" diff --git a/tag.go b/tag.go index e1c821fb..835c65f3 100644 --- a/tag.go +++ b/tag.go @@ -5,6 +5,7 @@ package xorm import ( + "errors" "fmt" "reflect" "strconv" @@ -15,18 +16,22 @@ import ( ) type tagContext struct { + engine *Engine + parsingTables map[reflect.Type]*core.Table + + table *core.Table + hasCacheTag bool + hasNoCacheTag bool + + col *core.Column + fieldValue reflect.Value + isIndex bool + isUnique bool + indexNames map[string]int + tagName string params []string preTag, nextTag string - table *core.Table - col *core.Column - fieldValue reflect.Value - isIndex bool - isUnique bool - indexNames map[string]int - engine *Engine - hasCacheTag bool - hasNoCacheTag bool ignoreNext bool } @@ -36,25 +41,26 @@ type tagHandler func(ctx *tagContext) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]tagHandler{ - "<-": OnlyFromDBTagHandler, - "->": OnlyToDBTagHandler, - "PK": PKTagHandler, - "NULL": NULLTagHandler, - "NOT": IgnoreTagHandler, - "AUTOINCR": AutoIncrTagHandler, - "DEFAULT": DefaultTagHandler, - "CREATED": CreatedTagHandler, - "UPDATED": UpdatedTagHandler, - "DELETED": DeletedTagHandler, - "VERSION": VersionTagHandler, - "UTC": UTCTagHandler, - "LOCAL": LocalTagHandler, - "NOTNULL": NotNullTagHandler, - "INDEX": IndexTagHandler, - "UNIQUE": UniqueTagHandler, - "CACHE": CacheTagHandler, - "NOCACHE": NoCacheTagHandler, - "COMMENT": CommentTagHandler, + "<-": OnlyFromDBTagHandler, + "->": OnlyToDBTagHandler, + "PK": PKTagHandler, + "NULL": NULLTagHandler, + "NOT": IgnoreTagHandler, + "AUTOINCR": AutoIncrTagHandler, + "DEFAULT": DefaultTagHandler, + "CREATED": CreatedTagHandler, + "UPDATED": UpdatedTagHandler, + "DELETED": DeletedTagHandler, + "VERSION": VersionTagHandler, + "UTC": UTCTagHandler, + "LOCAL": LocalTagHandler, + "NOTNULL": NotNullTagHandler, + "INDEX": IndexTagHandler, + "UNIQUE": UniqueTagHandler, + "CACHE": CacheTagHandler, + "NOCACHE": NoCacheTagHandler, + "COMMENT": CommentTagHandler, + "BELONGS_TO": BelongsToTagHandler, } ) @@ -256,7 +262,7 @@ func ExtendsTagHandler(ctx *tagContext) error { } fallthrough case reflect.Struct: - parentTable, err := ctx.engine.mapType(fieldValue) + parentTable, err := ctx.engine.mapType(ctx.parsingTables, fieldValue) if err != nil { return err } @@ -288,3 +294,44 @@ func NoCacheTagHandler(ctx *tagContext) error { } return nil } + +// BelongsToTagHandler describes belongs_to tag handler +func BelongsToTagHandler(ctx *tagContext) error { + if !isStruct(ctx.fieldValue.Type()) { + return errors.New("Tag belongs_to cannot be applied on non-struct field") + } + + ctx.col.AssociateType = core.AssociateBelongsTo + var t reflect.Value + if ctx.fieldValue.Kind() == reflect.Struct { + t = ctx.fieldValue + } else { + if ctx.fieldValue.Type().Kind() == reflect.Ptr && ctx.fieldValue.Type().Elem().Kind() == reflect.Struct { + if ctx.fieldValue.IsNil() { + t = reflect.New(ctx.fieldValue.Type().Elem()).Elem() + } else { + t = ctx.fieldValue + } + } else { + return errors.New("Only struct or ptr to struct field could add belongs_to flag") + } + } + + belongsT, err := ctx.engine.mapType(ctx.parsingTables, t) + if err != nil { + return err + } + pks := belongsT.PKColumns() + if len(pks) != 1 { + panic("unsupported non or composited primary key cascade") + return errors.New("blongs_to only should be as a tag of table has one primary key") + } + + ctx.col.AssociateTable = belongsT + ctx.col.SQLType = pks[0].SQLType + + if len(ctx.col.Name) == 0 { + ctx.col.Name = ctx.engine.ColumnMapper.Obj2Table(ctx.col.FieldName) + "_id" + } + return nil +} diff --git a/xorm_test.go b/xorm_test.go index 569bc681..f7581e06 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -12,6 +12,7 @@ import ( "github.com/go-xorm/core" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" _ "github.com/ziutek/mymysql/godrv" ) @@ -123,6 +124,8 @@ func TestMain(m *testing.M) { } func TestPing(t *testing.T) { + assert.NoError(t, prepareEngine()) + if err := testEngine.Ping(); err != nil { t.Fatal(err) }