Add support for association preloading

This commit is contained in:
Diego Sogari 2023-04-01 12:14:53 -03:00
parent d485abba57
commit 79554f640b
No known key found for this signature in database
GPG Key ID: 03A9A337B873A022
10 changed files with 912 additions and 30 deletions

View File

@ -1448,3 +1448,13 @@ func (engine *Engine) Transaction(f func(*Session) (interface{}, error)) (interf
return result, nil
}
func (engine *Engine) Preloads(preloads ...*Preload) *Session {
session := engine.NewSession()
session.isAutoClose = true
return session.Preloads(preloads...)
}
func (engine *Engine) Preload(path string) *Preload {
return NewPreload(path)
}

194
preload.go Normal file
View File

@ -0,0 +1,194 @@
package xorm
import (
"fmt"
"reflect"
"strings"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
type Preload struct {
path []string
cols []string
cond builder.Cond
noPrune bool
}
func NewPreload(path string) *Preload {
return &Preload{
path: strings.Split(path, "."),
cond: builder.NewCond(),
}
}
func (p *Preload) Cols(cols ...string) *Preload {
p.cols = append(p.cols, cols...)
return p
}
func (p *Preload) Where(cond builder.Cond) *Preload {
p.cond = p.cond.And(cond)
return p
}
func (p *Preload) NoPrune() *Preload {
p.noPrune = true
return p
}
type PreloadNode struct {
preload *Preload
children map[string]*PreloadNode
association *schemas.Association
ExtraCols []string
}
func NewPreloadNode() *PreloadNode {
return &PreloadNode{
children: make(map[string]*PreloadNode),
}
}
func (pn *PreloadNode) Add(preload *Preload) error {
return pn.add(preload, 0)
}
func (pn *PreloadNode) add(preload *Preload, index int) error {
if index == len(preload.path) {
if pn.preload != nil {
return fmt.Errorf("preload: duplicated path: %s", strings.Join(preload.path, ","))
}
pn.preload = preload
return nil
}
child, ok := pn.children[preload.path[index]]
if !ok {
child = NewPreloadNode()
pn.children[preload.path[index]] = child
}
return child.add(preload, index+1)
}
func (pn *PreloadNode) Validate(table *schemas.Table) error {
if pn.preload != nil {
for _, col := range pn.preload.cols {
if table.GetColumn(col) == nil {
return fmt.Errorf("preload: missing col %s in table %s", col, table.Name)
}
}
}
for name, node := range pn.children {
column := table.GetColumn(name)
if column == nil {
return fmt.Errorf("preload: missing field %s in struct %s", name, table.Type.Name())
}
if column.Association == nil {
return fmt.Errorf("preload: missing association in field %s", name)
}
if column.Association.JoinTable == nil && len(column.Association.SourceCol) > 0 {
pn.ExtraCols = append(pn.ExtraCols, column.Association.SourceCol)
}
if len(column.Association.TargetCol) > 0 {
pn.ExtraCols = append(pn.ExtraCols, table.PrimaryKeys[0]) // pk might be added many times, but that's ok
}
if column.Association.JoinTable == nil && len(column.Association.TargetCol) > 0 {
node.ExtraCols = append(node.ExtraCols, column.Association.TargetCol)
}
node.association = column.Association
if err := node.Validate(column.Association.RefTable); err != nil {
return err
}
}
return nil
}
func (pn *PreloadNode) Compute(session *Session, ownMap reflect.Value) error {
for _, node := range pn.children {
if err := node.compute(session, ownMap, reflect.Value{}); err != nil {
return err
}
}
return nil
}
func (pn *PreloadNode) compute(session *Session, ownMap, pruneMap reflect.Value) error {
// non-root node: pn.association is not nil
if err := pn.association.ValidateOwnMap(ownMap); err != nil {
return err
}
var joinMap reflect.Value
cond := pn.association.GetCond(ownMap)
if pn.association.JoinTable != nil {
var err error
cond, joinMap, err = pn.preloadJoin(session, cond)
if err != nil {
return err
}
}
refMap := pn.association.MakeRefMap()
preloadSession := session.Engine().Cols(pn.ExtraCols...).Where(cond)
if pn.preload != nil {
preloadSession.Cols(pn.preload.cols...).Where(pn.preload.cond)
}
if err := preloadSession.Find(refMap.Interface()); err != nil {
return err
}
var refPruneMap reflect.Value
if len(pn.children) > 0 && !(pn.preload != nil && (len(pn.preload.cols) > 0 || pn.preload.noPrune)) {
refPruneMap = reflect.MakeMap(reflect.MapOf(refMap.Type().Key(), reflect.TypeOf(true)))
refIter := refMap.MapRange()
for refIter.Next() {
refPruneMap.SetMapIndex(refIter.Key(), reflect.ValueOf(true))
}
}
for _, node := range pn.children {
if err := node.compute(session, refMap, refPruneMap); err != nil {
return err
}
}
if refPruneMap.IsValid() {
pruneIter := refPruneMap.MapRange()
for pruneIter.Next() {
refMap.SetMapIndex(pruneIter.Key(), reflect.Value{})
}
}
pn.association.Link(ownMap, refMap, pruneMap, joinMap)
return nil
}
func (pn *PreloadNode) preloadJoin(session *Session, cond builder.Cond) (builder.Cond, reflect.Value, error) {
joinSlicePtr := pn.association.MakeJoinSlice()
if err := session.Engine().
Table(pn.association.JoinTable.Name).Where(cond).
Cols(pn.association.SourceCol, pn.association.TargetCol).
Find(joinSlicePtr.Interface()); err != nil {
return nil, reflect.Value{}, err
}
joinSlice := joinSlicePtr.Elem()
joinMap := pn.association.MakeJoinMap()
for i := 0; i < joinSlice.Len(); i++ {
entry := joinSlice.Index(i)
pkSlice := joinMap.MapIndex(entry.Field(1))
if !pkSlice.IsValid() {
pkSlice = reflect.MakeSlice(reflect.SliceOf(pn.association.OwnPkType), 0, 0)
}
joinMap.SetMapIndex(entry.Field(1), reflect.Append(pkSlice, entry.Field(0)))
}
var refPks []interface{}
iter := joinMap.MapRange()
joinMap.MapKeys()
for iter.Next() {
refPks = append(refPks, iter.Key().Interface())
}
cond = builder.In(pn.association.RefTable.PrimaryKeys[0], refPks)
return cond, joinMap, nil
}

276
preload_test.go Normal file
View File

@ -0,0 +1,276 @@
package xorm
import (
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"sort"
"testing"
"xorm.io/builder"
)
// https://gitea.com/xorm/xorm/issues/2240
type Employee struct {
Id int64
Name string
BuddyId *int64
ManagerId *int64
Buddy *Employee `xorm:"belongs_to(buddy_id)"`
Apprentice *Employee `xorm:"has_one(buddy_id)"`
Manager *Employee `xorm:"belongs_to(manager_id)"`
Subordinates []*Employee `xorm:"has_many(manager_id)"`
Indications []*Employee `xorm:"many_to_many(employee_indication, indicator_id, indicated_id)"`
IndicatedBy []*Employee `xorm:"many_to_many(employee_indication, indicated_id, indicator_id)"`
}
func TestPreload(t *testing.T) {
engine, err := NewEngine("sqlite3", ":memory:")
require.NoError(t, err)
const sql = `
create table employee (
id integer primary key autoincrement,
name text not null,
buddy_id integer references employee(id) check (buddy_id <> id) unique,
manager_id integer references employee(id) check (manager_id <> id)
);
create table employee_indication (
indicator_id integer not null references employee(id),
indicated_id integer not null references employee(id),
primary key (indicator_id, indicated_id),
check (indicator_id <> indicated_id)
);
insert into employee (name) values ('John'), ('Bob');
insert into employee (name,manager_id) values ('Alice',1), ('Riya',2);
insert into employee (name,manager_id,buddy_id) values ('Emilie',1,3), ('Cynthia',2,4);
insert into employee_indication values (1,2), (1,3), (2,3), (2,4), (2,5), (3,5), (3,6);
-- John manages Alice and Emilie
-- Bob manages Riya and Cynthia
-- Alice is buddy of Emilie
-- Riya is buddy of Cynthia
-- John indicated Bob and Alice
-- Bob indicated Alice, Riya and Emilie
-- Alice indicated Emilie and Cynthia
`
_, err = engine.Exec(sql)
require.NoError(t, err)
var employee Employee
_, err = engine.Preloads(
engine.Preload("Indications.Buddy").Cols("name"),
engine.Preload("Indications").NoPrune(),
).Cols("name").Where(builder.Eq{"id": 2}).Get(&employee)
require.NoError(t, err)
// order is not preserved when preloading, so we compensate for it in the test
sort.Slice(employee.Indications, func(i, j int) bool {
return employee.Indications[i].Id < employee.Indications[j].Id
})
assert.Equal(t, Employee{
Id: 2,
Name: "Bob",
Indications: []*Employee{
{Id: 3},
{Id: 4},
{
Id: 5,
BuddyId: &[]int64{3}[0],
Buddy: &Employee{
Id: 3,
Name: "Alice",
},
},
},
}, employee)
var employees []*Employee
err = engine.Preloads(
// 1. preload the names of all subordinates of this employee's manager
engine.Preload("Manager.Subordinates").Cols("name"),
// 2. preload the name of the buddy of each employee indicated by this employee
engine.Preload("Indications.Buddy").Cols("name"),
// 3. preload the names of all employees who indicated this employee's apprentice, except non-subordinates
engine.Preload("Apprentice.IndicatedBy").Cols("name").Where(builder.NotNull{"manager_id"}),
// 4. preload the names of:
// all employees who don't have a maanger and were indicated by:
engine.Preload("Subordinates.IndicatedBy.Indications").Where(builder.IsNull{"manager_id"}).Cols("name"),
// employees whose name is 4 letters long and indicated:
engine.Preload("Subordinates.IndicatedBy").Where(builder.Like{"name", "____"}),
// this employee's subordinates who don't have a buddy
engine.Preload("Subordinates").Where(builder.IsNull{"buddy_id"}),
// 0. find the names of all employees
).Cols("name").Find(&employees)
require.NoError(t, err)
// order is not preserved when preloading, so we compensate for it in the test
for k := 2; k < 6; k++ {
sort.Slice(employees[k].Manager.Subordinates, func(i, j int) bool {
return employees[k].Manager.Subordinates[i].Id < employees[k].Manager.Subordinates[j].Id
})
}
sort.Slice(employees[2].Indications, func(i, j int) bool {
return employees[2].Indications[i].Id < employees[2].Indications[j].Id
})
expected := []*Employee{
{
Id: 1,
Name: "John",
Subordinates: []*Employee{
{
Id: 3,
ManagerId: &[]int64{1}[0],
IndicatedBy: []*Employee{
{
Id: 1,
Indications: []*Employee{
{
Id: 2,
Name: "Bob",
},
},
},
},
},
},
},
{
Id: 2,
Name: "Bob",
Indications: []*Employee{
{
Id: 5,
BuddyId: &[]int64{3}[0],
Buddy: &Employee{
Id: 3,
Name: "Alice",
},
},
},
},
{
Id: 3,
Name: "Alice",
ManagerId: &[]int64{1}[0],
Manager: &Employee{
Id: 1,
Subordinates: []*Employee{
{
Id: 3,
Name: "Alice",
ManagerId: &[]int64{1}[0],
},
{
Id: 5,
Name: "Emilie",
ManagerId: &[]int64{1}[0],
},
},
},
Apprentice: &Employee{
Id: 5,
BuddyId: &[]int64{3}[0],
IndicatedBy: []*Employee{
{
Id: 3,
Name: "Alice",
},
},
},
Indications: []*Employee{
{
Id: 5,
BuddyId: &[]int64{3}[0],
Buddy: &Employee{
Id: 3,
Name: "Alice",
},
},
{
Id: 6,
BuddyId: &[]int64{4}[0],
Buddy: &Employee{
Id: 4,
Name: "Riya",
},
},
},
},
{
Id: 4,
Name: "Riya",
ManagerId: &[]int64{2}[0],
Manager: &Employee{
Id: 2,
Subordinates: []*Employee{
{
Id: 4,
Name: "Riya",
ManagerId: &[]int64{2}[0],
},
{
Id: 6,
Name: "Cynthia",
ManagerId: &[]int64{2}[0],
},
},
},
Apprentice: &Employee{
Id: 6,
BuddyId: &[]int64{4}[0],
IndicatedBy: []*Employee{
{
Id: 3,
Name: "Alice",
},
},
},
},
{
Id: 5,
Name: "Emilie",
ManagerId: &[]int64{1}[0],
Manager: &Employee{
Id: 1,
Subordinates: []*Employee{
{
Id: 3,
Name: "Alice",
ManagerId: &[]int64{1}[0],
},
{
Id: 5,
Name: "Emilie",
ManagerId: &[]int64{1}[0],
},
},
},
},
{
Id: 6,
Name: "Cynthia",
ManagerId: &[]int64{2}[0],
Manager: &Employee{
Id: 2,
Subordinates: []*Employee{
{
Id: 4,
Name: "Riya",
ManagerId: &[]int64{2}[0],
},
{
Id: 6,
Name: "Cynthia",
ManagerId: &[]int64{2}[0],
},
},
},
},
}
assert.Equal(t, expected, employees)
}

167
schemas/association.go Normal file
View File

@ -0,0 +1,167 @@
package schemas
import (
"fmt"
"reflect"
"xorm.io/builder"
)
type Association struct {
OwnTable *Table
OwnColumn *Column
OwnPkType reflect.Type
RefTable *Table
RefPkType reflect.Type
JoinTable *Table // many_to_many
SourceCol string // belongs_to, many_to_many
TargetCol string // has_one, has_many, many_to_many
}
func (a *Association) MakeJoinSlice() reflect.Value {
return reflect.New(reflect.SliceOf(a.JoinTable.Type))
}
func (a *Association) MakeJoinMap() reflect.Value {
return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.SliceOf(a.OwnPkType)))
}
func (a *Association) MakeRefMap() reflect.Value {
return reflect.MakeMap(reflect.MapOf(a.RefPkType, reflect.PointerTo(a.RefTable.Type)))
}
func (a *Association) ValidateOwnMap(ownMap reflect.Value) error {
if ownMap.Type() != reflect.MapOf(a.OwnPkType, reflect.PointerTo(a.OwnTable.Type)) {
return fmt.Errorf("wrong map type: %v", ownMap.Type())
}
return nil
}
func (a *Association) GetCond(ownMap reflect.Value) builder.Cond {
if a.JoinTable != nil {
return a.condManyToMany(ownMap)
}
if len(a.SourceCol) > 0 {
return a.condBelongsTo(ownMap)
}
return a.condHasOneOrMany(ownMap)
}
func (a *Association) condBelongsTo(ownMap reflect.Value) builder.Cond {
pkMap := make(map[interface{}]bool)
fkCol := a.OwnTable.GetColumn(a.SourceCol)
iter := ownMap.MapRange()
for iter.Next() {
structPtr := iter.Value()
fk, _ := fkCol.ValueOfV(&structPtr)
if fk.Type().Kind() == reflect.Pointer {
if fk.IsNil() {
continue
}
*fk = fk.Elem()
}
// don't check fk.IsZero(), because it might be a valid fk value
pkMap[fk.Interface()] = true
}
pks := make([]interface{}, 0, len(pkMap))
for pk := range pkMap {
pks = append(pks, pk)
}
return builder.In(a.RefTable.PrimaryKeys[0], pks)
}
func (a *Association) condHasOneOrMany(ownMap reflect.Value) builder.Cond {
var pks []interface{}
iter := ownMap.MapRange()
for iter.Next() {
pks = append(pks, iter.Key().Interface())
}
return builder.In(a.TargetCol, pks)
}
func (a *Association) condManyToMany(ownMap reflect.Value) builder.Cond {
var pks []interface{}
iter := ownMap.MapRange()
for iter.Next() {
pks = append(pks, iter.Key().Interface())
}
return builder.In(a.SourceCol, pks)
}
func (a *Association) Link(ownMap, refMap, pruneMap, joinMap reflect.Value) {
if a.JoinTable != nil {
a.linkManyToMany(ownMap, refMap, pruneMap, joinMap)
} else if len(a.SourceCol) > 0 {
a.linkBelongsTo(ownMap, refMap, pruneMap)
} else {
a.linkHasOneOrMany(ownMap, refMap, pruneMap)
}
}
func (a *Association) linkBelongsTo(ownMap, refMap, pruneMap reflect.Value) {
fkCol := a.OwnTable.GetColumn(a.SourceCol)
iter := ownMap.MapRange()
for iter.Next() {
structPtr := iter.Value()
fk, _ := fkCol.ValueOfV(&structPtr)
if fk.Type().Kind() == reflect.Pointer {
if fk.IsNil() {
continue
}
*fk = fk.Elem()
}
// don't check fk.IsZero(), because it might be a valid fk value
refStructPtr := refMap.MapIndex(*fk)
if refStructPtr.IsValid() {
refField, _ := a.OwnColumn.ValueOfV(&structPtr)
refField.Set(refStructPtr)
if pruneMap.IsValid() {
pruneMap.SetMapIndex(iter.Key(), reflect.Value{}) // do not prune this key
}
}
}
}
func (a *Association) linkHasOneOrMany(ownMap, refMap, pruneMap reflect.Value) {
hasMany := a.OwnColumn.FieldType.Kind() == reflect.Slice
fkCol := a.RefTable.GetColumn(a.TargetCol)
iter := refMap.MapRange()
for iter.Next() {
refStructPtr := iter.Value()
fk, _ := fkCol.ValueOfV(&refStructPtr)
if fk.Type().Kind() == reflect.Pointer {
if fk.IsNil() {
continue
}
*fk = fk.Elem()
}
// don't check fk.IsZero(), because it might be a valid fk value
structPtr := ownMap.MapIndex(*fk) // structPtr should be valid at this point
refField, _ := a.OwnColumn.ValueOfV(&structPtr)
if hasMany {
refField.Set(reflect.Append(*refField, refStructPtr))
} else {
refField.Set(refStructPtr)
}
if pruneMap.IsValid() {
pruneMap.SetMapIndex(*fk, reflect.Value{}) // do not prune this key
}
}
}
func (a *Association) linkManyToMany(ownMap, refMap, pruneMap, joinMap reflect.Value) {
iter := refMap.MapRange()
for iter.Next() {
refStructPtr := iter.Value()
refPk := iter.Key() // refPk should not be a pointer
pkSlice := joinMap.MapIndex(refPk) // pkSlice should be valid at this point
for i := 0; i < pkSlice.Len(); i++ {
pk := pkSlice.Index(i)
structPtr := ownMap.MapIndex(pk) // structPtr should be valid at this point
refField, _ := a.OwnColumn.ValueOfV(&structPtr)
refField.Set(reflect.Append(*refField, refStructPtr))
if pruneMap.IsValid() {
pruneMap.SetMapIndex(pk, reflect.Value{}) // do not prune this key
}
}
}
}

View File

@ -21,9 +21,9 @@ const (
// Column defines database column
type Column struct {
Name string
TableName string
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
FieldName string // Available only when parsed from a struct
FieldIndex []int // Available only when parsed from a struct
FieldType reflect.Type // Available only when parsed from a struct
SQLType SQLType
IsJSON bool
Length int64
@ -45,6 +45,7 @@ type Column struct {
DisableTimeZone bool
TimeZone *time.Location // column specified time zone
Comment string
Association *Association
}
// NewColumn creates a new column
@ -52,7 +53,6 @@ func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullab
return &Column{
Name: name,
IsJSON: sqlType.IsJson(),
TableName: "",
FieldName: fieldName,
SQLType: sqlType,
Length: len1,

View File

@ -87,6 +87,7 @@ type Session struct {
ctx context.Context
sessionType sessionType
preloadNode *PreloadNode
}
func newSessionID() string {
@ -800,3 +801,17 @@ func (session *Session) NoVersionCheck() *Session {
session.statement.CheckVersion = false
return session
}
// Preloads adds preloads
func (session *Session) Preloads(preloads ...*Preload) *Session {
if session.preloadNode == nil {
session.preloadNode = NewPreloadNode()
}
for _, preload := range preloads {
if err := session.preloadNode.Add(preload); err != nil {
session.statement.LastError = err
break
}
}
return session
}

View File

@ -91,6 +91,12 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem()
if session.preloadNode != nil {
if sliceElementType.Kind() != reflect.Ptr || sliceElementType.Elem().Kind() != reflect.Struct {
return errors.New("preloading requires a pointer to a slice or a map of struct pointer")
}
}
tp := tpStruct
if session.statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr {
@ -141,6 +147,18 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
}
if session.preloadNode != nil {
if err := session.preloadNode.Validate(table); err != nil {
return err
}
// we need the columns required for the preloads
if !session.statement.ColumnMap.IsEmpty() {
for _, k := range session.preloadNode.ExtraCols {
session.statement.ColumnMap.Add(k)
}
}
}
sqlStr, args, err := session.statement.GenFindSQL(autoCond)
if err != nil {
return err
@ -158,7 +176,26 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
}
return session.noCacheFind(table, sliceValue, sqlStr, args...)
err = session.noCacheFind(table, sliceValue, sqlStr, args...)
if err != nil {
return err
}
if session.preloadNode != nil {
if isSlice {
// convert to a map before preloading
originalSliceValue := sliceValue
pkColumn := table.PKColumns()[0]
sliceValue = reflect.MakeMap(reflect.MapOf(pkColumn.FieldType, sliceElementType))
for i := 0; i < originalSliceValue.Len(); i++ {
element := originalSliceValue.Index(i)
pkValue, _ := pkColumn.ValueOfV(&element)
sliceValue.SetMapIndex(*pkValue, element)
}
}
return session.preloadNode.Compute(session, sliceValue)
}
return nil
}
func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect.Value, sqlStr string, args ...interface{}) error {

View File

@ -71,12 +71,28 @@ func (session *Session) get(beans ...interface{}) (bool, error) {
if err := session.statement.SetRefBean(beans[0]); err != nil {
return false, err
}
} else if session.preloadNode != nil {
return false, errors.New("preloading requires a pointer to struct")
}
var sqlStr string
var args []interface{}
var err error
table := session.statement.RefTable
if session.preloadNode != nil {
if err := session.preloadNode.Validate(table); err != nil {
return false, err
}
// we need the columns required for the preloads
if !session.statement.ColumnMap.IsEmpty() {
for _, k := range session.preloadNode.ExtraCols {
session.statement.ColumnMap.Add(k)
}
}
}
if session.statement.RawSQL == "" {
if len(session.statement.TableName()) == 0 {
return false, ErrTableNotFound
@ -91,8 +107,6 @@ func (session *Session) get(beans ...interface{}) (bool, error) {
args = session.statement.RawParams
}
table := session.statement.RefTable
if session.statement.ColumnMap.IsEmpty() && session.canCache() && isStruct {
if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil &&
!session.statement.GetUnscoped() {
@ -122,6 +136,19 @@ func (session *Session) get(beans ...interface{}) (bool, error) {
return has, err
}
if session.preloadNode != nil {
// convert to a map before preloading
pkColumn := table.PKColumns()[0]
sliceValue := reflect.MakeMap(reflect.MapOf(pkColumn.FieldType, beanValue.Type()))
dataStruct := beanValue.Elem()
pkValue, _ := pkColumn.ValueOfV(&dataStruct)
sliceValue.SetMapIndex(*pkValue, beanValue)
if err := session.preloadNode.Compute(session, sliceValue); err != nil {
return false, err
}
}
if context != nil && isStruct {
context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0])
}

View File

@ -84,7 +84,7 @@ func (parser *Parser) SetIdentifier(identifier string) {
// ParseWithCache parse a struct with cache
func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) {
t := v.Type()
t := reflect.Indirect(v).Type()
tableI, ok := parser.tableCache.Load(t)
if ok {
return tableI.(*schemas.Table), nil
@ -172,6 +172,7 @@ func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructFi
field.Name, sqlType, sqlType.DefaultLength,
sqlType.DefaultLength2, true)
col.FieldIndex = []int{fieldIndex}
col.FieldType = field.Type
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
col.IsAutoIncrement = true
@ -185,6 +186,7 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f
col := &schemas.Column{
FieldName: field.Name,
FieldIndex: []int{fieldIndex},
FieldType: field.Type,
Nullable: true,
IsPrimaryKey: false,
IsAutoIncrement: false,

View File

@ -5,6 +5,7 @@
package tags
import (
"errors"
"fmt"
"reflect"
"strconv"
@ -101,28 +102,32 @@ type Handler func(ctx *Context) error
// defaultTagHandlers enumerates all the default tag handler
var defaultTagHandlers = map[string]Handler{
"-": IgnoreHandler,
"<-": OnlyFromDBTagHandler,
"->": OnlyToDBTagHandler,
"PK": PKTagHandler,
"NULL": NULLTagHandler,
"NOT": NotTagHandler,
"AUTOINCR": AutoIncrTagHandler,
"DEFAULT": DefaultTagHandler,
"CREATED": CreatedTagHandler,
"UPDATED": UpdatedTagHandler,
"DELETED": DeletedTagHandler,
"VERSION": VersionTagHandler,
"UTC": UTCTagHandler,
"LOCAL": LocalTagHandler,
"NOTNULL": NotNullTagHandler,
"INDEX": IndexTagHandler,
"UNIQUE": UniqueTagHandler,
"CACHE": CacheTagHandler,
"NOCACHE": NoCacheTagHandler,
"COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler,
"-": IgnoreHandler,
"<-": OnlyFromDBTagHandler,
"->": OnlyToDBTagHandler,
"PK": PKTagHandler,
"NULL": NULLTagHandler,
"NOT": NotTagHandler,
"AUTOINCR": AutoIncrTagHandler,
"DEFAULT": DefaultTagHandler,
"CREATED": CreatedTagHandler,
"UPDATED": UpdatedTagHandler,
"DELETED": DeletedTagHandler,
"VERSION": VersionTagHandler,
"UTC": UTCTagHandler,
"LOCAL": LocalTagHandler,
"NOTNULL": NotNullTagHandler,
"INDEX": IndexTagHandler,
"UNIQUE": UniqueTagHandler,
"CACHE": CacheTagHandler,
"NOCACHE": NoCacheTagHandler,
"COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler,
"BELONGS_TO": BelongsToTagHandler,
"HAS_ONE": HasOneTagHandler,
"HAS_MANY": HasManyTagHandler,
"MANY_TO_MANY": ManyToManyTagHandler,
}
func init() {
@ -396,3 +401,152 @@ func NoCacheTagHandler(ctx *Context) error {
}
return nil
}
func isStructPtr(t reflect.Type) bool {
return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct
}
func isStructPtrSlice(t reflect.Type) bool {
return t.Kind() == reflect.Slice && isStructPtr(t.Elem())
}
func getRefTableAndPks(ctx *Context, refType reflect.Type) ([]*schemas.Column, *schemas.Table, []*schemas.Column, error) {
pks := ctx.table.PKColumns()
if len(pks) != 1 {
return nil, nil, nil, errors.New("a single-column primary key must be declared before the association")
}
refTable := ctx.table // self-reference case
refPks := pks
if refTable.Type != refType {
var err error
refTable, err = ctx.parser.ParseWithCache(reflect.New(refType).Elem())
if err != nil {
return nil, nil, nil, err
}
refPks = refTable.PKColumns()
if len(refPks) != 1 {
return nil, nil, nil, errors.New("only single-column primary key is supported in associations")
}
}
return pks, refTable, refPks, nil
}
// BelongsToTagHandler describes belongs_to tag handler
func BelongsToTagHandler(ctx *Context) error {
if !isStructPtr(ctx.fieldValue.Type()) {
return errors.New("tag belongs_to can only be applied to struct pointer field")
}
if len(ctx.params) != 1 {
return errors.New("tag belongs_to requires one parameter")
}
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem())
if err != nil {
return err
}
ctx.col.Association = &schemas.Association{
OwnTable: ctx.table,
OwnColumn: ctx.col,
OwnPkType: pks[0].FieldType,
RefTable: refTable,
RefPkType: refPks[0].FieldType,
SourceCol: ctx.params[0],
}
ctx.col.Name = ctx.col.FieldName
return nil
}
// HasOneTagHandler describes has_one tag handler
func HasOneTagHandler(ctx *Context) error {
if !isStructPtr(ctx.fieldValue.Type()) {
return errors.New("tag has_one can only be applied to struct pointer field")
}
if len(ctx.params) != 1 {
return errors.New("tag has_one requires one parameter")
}
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem())
if err != nil {
return err
}
ctx.col.Association = &schemas.Association{
OwnTable: ctx.table,
OwnColumn: ctx.col,
OwnPkType: pks[0].FieldType,
RefTable: refTable,
RefPkType: refPks[0].FieldType,
TargetCol: ctx.params[0],
}
ctx.col.Name = ctx.col.FieldName
return nil
}
// HasManyTagHandler describes has_many tag handler
func HasManyTagHandler(ctx *Context) error {
if !isStructPtrSlice(ctx.fieldValue.Type()) {
return errors.New("tag has_many can only be applied to slice of struct pointer field")
}
if len(ctx.params) != 1 {
return errors.New("tag has_many requires one parameter")
}
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem().Elem())
if err != nil {
return err
}
ctx.col.Association = &schemas.Association{
OwnTable: ctx.table,
OwnColumn: ctx.col,
OwnPkType: pks[0].FieldType,
RefTable: refTable,
RefPkType: refPks[0].FieldType,
TargetCol: ctx.params[0],
}
ctx.col.Name = ctx.col.FieldName
return nil
}
// ManyToManyTagHandler describes many_to_many tag handler
func ManyToManyTagHandler(ctx *Context) error {
if !isStructPtrSlice(ctx.fieldValue.Type()) {
return errors.New("tag many_to_many can only be applied to slice of struct pointer field")
}
if len(ctx.params) != 3 {
return errors.New("tag many_to_many requires 3 parameters")
}
pks, refTable, refPks, err := getRefTableAndPks(ctx, ctx.fieldValue.Type().Elem().Elem())
if err != nil {
return err
}
joinType := reflect.StructOf([]reflect.StructField{
{
Name: ctx.parser.GetColumnMapper().Table2Obj(ctx.params[1]),
Type: pks[0].FieldType,
},
{
Name: ctx.parser.GetColumnMapper().Table2Obj(ctx.params[2]),
Type: refPks[0].FieldType,
},
})
joinTable, err := ctx.parser.ParseWithCache(reflect.New(joinType).Elem())
if err != nil {
return err
}
joinTable.Name = ctx.params[0]
ctx.col.Association = &schemas.Association{
OwnTable: ctx.table,
OwnColumn: ctx.col,
OwnPkType: pks[0].FieldType,
RefTable: refTable,
RefPkType: refPks[0].FieldType,
JoinTable: joinTable,
SourceCol: ctx.params[1],
TargetCol: ctx.params[2],
}
ctx.col.Name = ctx.col.FieldName
return nil
}