diff --git a/engine.go b/engine.go index 962df125..0c794512 100644 --- a/engine.go +++ b/engine.go @@ -923,7 +923,16 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { engine: engine, } - if strings.ToUpper(tags[0]) == "EXTENDS" { + if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") { + pStart := strings.Index(tags[0], "(") + if pStart > -1 && strings.HasSuffix(tags[0], ")") { + var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool { + return r == '\'' || r == '"' + }) + + ctx.params = []string{tagPrefix} + } + if err := ExtendsTagHandler(&ctx); err != nil { return nil, err } diff --git a/tag.go b/tag.go index eb87be78..6feb581a 100644 --- a/tag.go +++ b/tag.go @@ -244,6 +244,7 @@ func SQLTypeTagHandler(ctx *tagContext) error { // ExtendsTagHandler describes extends tag handler func ExtendsTagHandler(ctx *tagContext) error { var fieldValue = ctx.fieldValue + var isPtr = false switch fieldValue.Kind() { case reflect.Ptr: f := fieldValue.Type().Elem() @@ -254,6 +255,7 @@ func ExtendsTagHandler(ctx *tagContext) error { fieldValue = reflect.New(f).Elem() } } + isPtr = true fallthrough case reflect.Struct: parentTable, err := ctx.engine.mapType(fieldValue) @@ -262,6 +264,24 @@ func ExtendsTagHandler(ctx *tagContext) error { } for _, col := range parentTable.Columns() { col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) + + var tagPrefix = ctx.col.FieldName + if len(ctx.params) > 0 { + col.Nullable = isPtr + tagPrefix = ctx.params[0] + if col.IsPrimaryKey { + col.Name = ctx.col.FieldName + col.IsPrimaryKey = false + } else { + col.Name = fmt.Sprintf("%v%v", tagPrefix, col.Name) + } + } + + if col.Nullable { + col.IsAutoIncrement = false + col.IsPrimaryKey = false + } + ctx.table.AddColumn(col) for indexName, indexType := range col.Indexes { addIndex(indexName, ctx.table, col, indexType) diff --git a/tag_extends_test.go b/tag_extends_test.go index a12a9a55..5a8031f0 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -486,3 +486,123 @@ func TestExtends4(t *testing.T) { panic(err) } } + +type Size struct { + ID int64 `xorm:"int(4) 'id' pk autoincr"` + Width float32 `json:"width" xorm:"float 'Width'"` + Height float32 `json:"height" xorm:"float 'Height'"` +} + +type Book struct { + ID int64 `xorm:"int(4) 'id' pk autoincr"` + SizeOpen *Size `xorm:"extends('Open')"` + SizeClosed *Size `xorm:"extends('Closed')"` + Size *Size `xorm:"extends('')"` +} + +func TestExtends5(t *testing.T) { + assert.NoError(t, prepareEngine()) + err := testEngine.DropTables(&Book{}, &Size{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = testEngine.CreateTables(&Size{}, &Book{}) + if err != nil { + t.Error(err) + panic(err) + } + + var sc = Size{Width: 0.2, Height: 0.4} + var so = Size{Width: 0.2, Height: 0.8} + var s = Size{Width: 0.15, Height: 1.5} + var bk1 = Book{ + SizeOpen: &so, + SizeClosed: &sc, + Size: &s, + } + var bk2 = Book{ + SizeOpen: &so, + } + var bk3 = Book{ + SizeClosed: &sc, + Size: &s, + } + var bk4 = Book{} + var bk5 = Book{Size: &s} + _, err = testEngine.Insert(&sc, &so, &s, &bk1, &bk2, &bk3, &bk4, &bk5) + if err != nil { + t.Fatal(err) + } + + var books = map[int64]Book{ + bk1.ID: bk1, + bk2.ID: bk2, + bk3.ID: bk3, + bk4.ID: bk4, + bk5.ID: bk5, + } + + session := testEngine.NewSession() + defer session.Close() + + var mapper = testEngine.GetTableMapper().Obj2Table + var quote = testEngine.Quote + bookTableName := quote(testEngine.TableName(mapper("Book"), true)) + sizeTableName := quote(testEngine.TableName(mapper("Size"), true)) + + list := make([]Book, 0) + err = session. + Select(fmt.Sprintf( + "%s.%s, sc.%s AS %s, sc.%s AS %s, s.%s, s.%s", + quote(bookTableName), + quote("id"), + quote("Width"), + quote("ClosedWidth"), + quote("Height"), + quote("ClosedHeight"), + quote("Width"), + quote("Height"), + )). + Table(bookTableName). + Join( + "LEFT", + sizeTableName+" AS `sc`", + bookTableName+".`SizeClosed`=sc.`id`", + ). + Join( + "LEFT", + sizeTableName+" AS `s`", + bookTableName+".`Size`=s.`id`", + ). + Find(&list) + if err != nil { + t.Error(err) + panic(err) + } + + for _, book := range list { + if ok := assert.Equal(t, books[book.ID].SizeClosed.Width, book.SizeClosed.Width); !ok { + t.Error("Not bounded size closed") + panic("Not bounded size closed") + } + + if ok := assert.Equal(t, books[book.ID].SizeClosed.Height, book.SizeClosed.Height); !ok { + t.Error("Not bounded size closed") + panic("Not bounded size closed") + } + + if books[book.ID].Size != nil || book.Size != nil { + if ok := assert.Equal(t, books[book.ID].Size.Width, book.Size.Width); !ok { + t.Error("Not bounded size") + panic("Not bounded size") + } + + if ok := assert.Equal(t, books[book.ID].Size.Height, book.Size.Height); !ok { + t.Error("Not bounded size") + panic("Not bounded size") + } + } + } +}