refactor automaptype

This commit is contained in:
Lunny Xiao 2017-04-02 18:02:47 +08:00
parent a0042a7117
commit 7e70eb8222
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 168 additions and 119 deletions

View File

@ -217,10 +217,15 @@ func (engine *Engine) NoCascade() *Session {
} }
// MapCacher Set a table use a special cacher // MapCacher Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) { func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) error {
v := rValue(bean) v := rValue(bean)
tb := engine.autoMapType(v) tb, err := engine.autoMapType(v)
if err != nil {
return err
}
tb.Cacher = cacher tb.Cacher = cacher
return nil
} }
// NewDB provides an interface to operate database directly // NewDB provides an interface to operate database directly
@ -776,7 +781,7 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions) return session.Having(conditions)
} }
func (engine *Engine) autoMapType(v reflect.Value) *core.Table { func (engine *Engine) autoMapType(v reflect.Value) (*core.Table, error) {
t := v.Type() t := v.Type()
engine.mutex.Lock() engine.mutex.Lock()
defer engine.mutex.Unlock() defer engine.mutex.Unlock()
@ -785,8 +790,9 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
var err error var err error
table, err = engine.mapType(v) table, err = engine.mapType(v)
if err != nil { if err != nil {
engine.logger.Error(err) return nil, err
} else { }
engine.Tables[t] = table engine.Tables[t] = table
if engine.Cacher != nil { if engine.Cacher != nil {
if v.CanAddr() { if v.CanAddr() {
@ -796,13 +802,11 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
} }
} }
} }
} return table, nil
return table
} }
// GobRegister register one struct to gob for cache use // GobRegister register one struct to gob for cache use
func (engine *Engine) GobRegister(v interface{}) *Engine { func (engine *Engine) GobRegister(v interface{}) *Engine {
//fmt.Printf("Type: %[1]T => Data: %[1]#v\n", v)
gob.Register(v) gob.Register(v)
return engine return engine
} }
@ -813,10 +817,19 @@ type Table struct {
Name string Name string
} }
// IsValid if table is valid
func (t *Table) IsValid() bool {
return t.Table != nil && len(t.Name) > 0
}
// TableInfo get table info according to bean's content // TableInfo get table info according to bean's content
func (engine *Engine) TableInfo(bean interface{}) *Table { func (engine *Engine) TableInfo(bean interface{}) *Table {
v := rValue(bean) v := rValue(bean)
return &Table{engine.autoMapType(v), engine.tbName(v)} tb, err := engine.autoMapType(v)
if err != nil {
engine.logger.Error(err)
}
return &Table{tb, engine.tbName(v)}
} }
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@ -1066,8 +1079,21 @@ func (engine *Engine) IdOfV(rv reflect.Value) core.PK {
// IDOfV get id from one value of struct // IDOfV get id from one value of struct
func (engine *Engine) IDOfV(rv reflect.Value) core.PK { func (engine *Engine) IDOfV(rv reflect.Value) core.PK {
pk, err := engine.idOfV(rv)
if err != nil {
engine.logger.Error(err)
return nil
}
return pk
}
func (engine *Engine) idOfV(rv reflect.Value) (core.PK, error) {
v := reflect.Indirect(rv) v := reflect.Indirect(rv)
table := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil {
return nil, err
}
pk := make([]interface{}, len(table.PrimaryKeys)) pk := make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() { for i, col := range table.PKColumns() {
pkField := v.FieldByName(col.FieldName) pkField := v.FieldByName(col.FieldName)
@ -1080,7 +1106,7 @@ func (engine *Engine) IDOfV(rv reflect.Value) core.PK {
pk[i] = pkField.Uint() pk[i] = pkField.Uint()
} }
} }
return core.PK(pk) return core.PK(pk), nil
} }
// CreateIndexes create indexes // CreateIndexes create indexes
@ -1101,13 +1127,6 @@ func (engine *Engine) getCacher2(table *core.Table) core.Cacher {
return table.Cacher return table.Cacher
} }
func (engine *Engine) getCacher(v reflect.Value) core.Cacher {
if table := engine.autoMapType(v); table != nil {
return table.Cacher
}
return engine.Cacher
}
// ClearCacheBean if enabled cache, clear the cache bean // ClearCacheBean if enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id string) error { func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
v := rValue(bean) v := rValue(bean)
@ -1116,7 +1135,10 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
return errors.New("error params") return errors.New("error params")
} }
tableName := engine.tbName(v) tableName := engine.tbName(v)
table := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher cacher := table.Cacher
if cacher == nil { if cacher == nil {
cacher = engine.Cacher cacher = engine.Cacher
@ -1137,7 +1159,11 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
return errors.New("error params") return errors.New("error params")
} }
tableName := engine.tbName(v) tableName := engine.tbName(v)
table := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil {
return err
}
cacher := table.Cacher cacher := table.Cacher
if cacher == nil { if cacher == nil {
cacher = engine.Cacher cacher = engine.Cacher
@ -1157,7 +1183,11 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableName := engine.tbName(v) tableName := engine.tbName(v)
table := engine.autoMapType(v) table, err := engine.autoMapType(v)
fmt.Println(v, table, err)
if err != nil {
return err
}
s := engine.NewSession() s := engine.NewSession()
defer s.Close() defer s.Close()

View File

@ -606,14 +606,16 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
} }
} }
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
table := session.Engine.autoMapType(*fieldValue) table, err := session.Engine.autoMapType(*fieldValue)
if table != nil { if err != nil {
return nil, err
}
hasAssigned = true hasAssigned = true
if len(table.PrimaryKeys) != 1 { if len(table.PrimaryKeys) != 1 {
panic("unsupported non or composited primary key cascade") panic("unsupported non or composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(core.PK, len(table.PrimaryKeys))
var err error
pk[0], err = asKind(vv, rawValueType) pk[0], err = asKind(vv, rawValueType)
if err != nil { if err != nil {
return nil, err return nil, err
@ -638,9 +640,6 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
return nil, errors.New("cascade obj is not exist") return nil, errors.New("cascade obj is not exist")
} }
} }
} else {
session.Engine.logger.Error("unsupported struct type in Scan: ", fieldValue.Type().String())
}
} }
case reflect.Ptr: case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above // !nashtsai! TODO merge duplicated codes above

View File

@ -208,15 +208,17 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
v = x v = x
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
table := session.Engine.autoMapType(*fieldValue) table, err := session.Engine.autoMapType(*fieldValue)
if table != nil { if err != nil {
return err
}
// TODO: current only support 1 primary key // TODO: current only support 1 primary key
if len(table.PrimaryKeys) > 1 { if len(table.PrimaryKeys) > 1 {
panic("unsupported composited primary key cascade") panic("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(core.PK, len(table.PrimaryKeys))
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
var err error
pk[0], err = str2PK(string(data), rawValueType) pk[0], err = str2PK(string(data), rawValueType)
if err != nil { if err != nil {
return err return err
@ -240,9 +242,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("cascade obj is not exist") return errors.New("cascade obj is not exist")
} }
} }
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
}
} }
} }
case reflect.Ptr: case reflect.Ptr:
@ -493,13 +492,15 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
default: default:
if session.Statement.UseCascade { if session.Statement.UseCascade {
structInter := reflect.New(fieldType.Elem()) structInter := reflect.New(fieldType.Elem())
table := session.Engine.autoMapType(structInter.Elem()) table, err := session.Engine.autoMapType(structInter.Elem())
if table != nil { if err != nil {
return err
}
if len(table.PrimaryKeys) > 1 { if len(table.PrimaryKeys) > 1 {
panic("unsupported composited primary key cascade") panic("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(core.PK, len(table.PrimaryKeys))
var err error
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
pk[0], err = str2PK(string(data), rawValueType) pk[0], err = str2PK(string(data), rawValueType)
if err != nil { if err != nil {
@ -523,7 +524,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
return errors.New("cascade obj is not exist") return errors.New("cascade obj is not exist")
} }
} }
}
} else { } else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
} }
@ -603,7 +603,10 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return v.Value() return v.Value()
} }
fieldTable := session.Engine.autoMapType(fieldValue) fieldTable, err := session.Engine.autoMapType(fieldValue)
if err != nil {
return nil, err
}
if len(fieldTable.PrimaryKeys) == 1 { if len(fieldTable.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
return pkField.Interface(), nil return pkField.Interface(), nil

View File

@ -234,7 +234,11 @@ func (session *Session) noCacheFind(table *core.Table, containerValue reflect.Va
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
var newValue = newElemFunc(fields) var newValue = newElemFunc(fields)
dataStruct := rValue(newValue.Interface()) dataStruct := rValue(newValue.Interface())
return session.rows2Beans(rawRows, fields, len(fields), session.Engine.autoMapType(dataStruct), newElemFunc, containerValueSetFunc) tb, err := session.Engine.autoMapType(dataStruct)
if err != nil {
return err
}
return session.rows2Beans(rawRows, fields, len(fields), tb, newElemFunc, containerValueSetFunc)
} }
for rawRows.Next() { for rawRows.Next() {
@ -407,7 +411,10 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if rv.Kind() != reflect.Ptr { if rv.Kind() != reflect.Ptr {
rv = rv.Addr() rv = rv.Addr()
} }
id := session.Engine.IdOfV(rv) id, err := session.Engine.idOfV(rv)
if err != nil {
return err
}
sid, err := id.ToString() sid, err := id.ToString()
if err != nil { if err != nil {
return err return err

View File

@ -207,9 +207,14 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
return statement return statement
} }
func (statement *Statement) setRefValue(v reflect.Value) { func (statement *Statement) setRefValue(v reflect.Value) error {
statement.RefTable = statement.Engine.autoMapType(reflect.Indirect(v)) var err error
statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
if err != nil {
return err
}
statement.tableName = statement.Engine.tbName(v) statement.tableName = statement.Engine.tbName(v)
return nil
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct // Table tempororily set table name, the parameter could be a string or a pointer of struct
@ -219,7 +224,12 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
if t.Kind() == reflect.String { if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string) statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.autoMapType(v) var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v) statement.AltTableName = statement.Engine.tbName(v)
} }
return statement return statement