resolved #89: if struct has func, the struct's name is the result

This commit is contained in:
Lunny Xiao 2014-04-08 16:46:23 +08:00
parent 9d5f834eb2
commit 9b23e7d6c0
5 changed files with 83 additions and 33 deletions

View File

@ -3990,6 +3990,28 @@ func testCompositeKey2(engine *Engine, t *testing.T) {
} }
} }
type CustomTableName struct {
Id int64
Name string
}
func (c *CustomTableName) TableName() string {
return "customtablename"
}
func testCustomTableName(engine *Engine, t *testing.T) {
c := new(CustomTableName)
err := engine.DropTables(c)
if err != nil {
t.Error(err)
}
err = engine.CreateTables(c)
if err != nil {
t.Error(err)
}
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -4100,6 +4122,8 @@ func testAll2(engine *Engine, t *testing.T) {
testProcessors(engine, t) testProcessors(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
fmt.Println("-------------- testCustomTableName --------------")
testCustomTableName(engine, t)
} }
// !nash! the 3rd set of the test is intended for non-cache enabled engine // !nash! the 3rd set of the test is intended for non-cache enabled engine

View File

@ -157,9 +157,9 @@ func (engine *Engine) NoCascade() *Session {
// Set a table use a special cacher // Set a table use a special cacher
func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) {
t := rType(bean) v := rValue(bean)
engine.autoMapType(t) engine.autoMapType(v)
engine.Tables[t].Cacher = cacher engine.Tables[v.Type()].Cacher = cacher
} }
// OpenDB provides a interface to operate database directly. // OpenDB provides a interface to operate database directly.
@ -435,12 +435,13 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions) return session.Having(conditions)
} }
func (engine *Engine) autoMapType(t reflect.Type) *Table { func (engine *Engine) autoMapType(v reflect.Value) *Table {
t := v.Type()
engine.mutex.RLock() engine.mutex.RLock()
table, ok := engine.Tables[t] table, ok := engine.Tables[t]
engine.mutex.RUnlock() engine.mutex.RUnlock()
if !ok { if !ok {
table = engine.mapType(t) table = engine.mapType(v)
engine.mutex.Lock() engine.mutex.Lock()
engine.Tables[t] = table engine.Tables[t] = table
engine.mutex.Unlock() engine.mutex.Unlock()
@ -449,8 +450,8 @@ func (engine *Engine) autoMapType(t reflect.Type) *Table {
} }
func (engine *Engine) autoMap(bean interface{}) *Table { func (engine *Engine) autoMap(bean interface{}) *Table {
t := rType(bean) v := rValue(bean)
return engine.autoMapType(t) return engine.autoMapType(v)
} }
func (engine *Engine) newTable() *Table { func (engine *Engine) newTable() *Table {
@ -475,9 +476,24 @@ func addIndex(indexName string, table *Table, col *Column, indexType int) {
} }
} }
func (engine *Engine) mapType(t reflect.Type) *Table { func (engine *Engine) mapType(v reflect.Value) *Table {
t := v.Type()
table := engine.newTable() table := engine.newTable()
table.Name = engine.tableMapper.Obj2Table(t.Name()) method := v.MethodByName("TableName")
if !method.IsValid() {
method = v.Addr().MethodByName("TableName")
}
if method.IsValid() {
params := []reflect.Value{}
results := method.Call(params)
if len(results) == 1 {
table.Name = results[0].Interface().(string)
}
}
if table.Name == "" {
table.Name = engine.tableMapper.Obj2Table(t.Name())
}
table.Type = t table.Type = t
var idFieldColName string var idFieldColName string
@ -487,7 +503,8 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
tag := t.Field(i).Tag tag := t.Field(i).Tag
ormTagStr := tag.Get(engine.TagIdentifier) ormTagStr := tag.Get(engine.TagIdentifier)
var col *Column var col *Column
fieldType := t.Field(i).Type fieldValue := v.Field(i)
fieldType := fieldValue.Type()
if ormTagStr != "" { if ormTagStr != "" {
col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false,
@ -500,7 +517,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
} }
if (strings.ToUpper(tags[0]) == "EXTENDS") && if (strings.ToUpper(tags[0]) == "EXTENDS") &&
(fieldType.Kind() == reflect.Struct) { (fieldType.Kind() == reflect.Struct) {
parentTable := engine.mapType(fieldType) parentTable := engine.mapType(fieldValue)
for name, col := range parentTable.Columns { for name, col := range parentTable.Columns {
col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName)
table.Columns[strings.ToLower(name)] = col table.Columns[strings.ToLower(name)] = col
@ -671,19 +688,20 @@ func (engine *Engine) mapping(beans ...interface{}) (e error) {
engine.mutex.Lock() engine.mutex.Lock()
defer engine.mutex.Unlock() defer engine.mutex.Unlock()
for _, bean := range beans { for _, bean := range beans {
t := rType(bean) v := rValue(bean)
engine.Tables[t] = engine.mapType(t) engine.Tables[v.Type()] = engine.mapType(v)
} }
return return
} }
// If a table has any reocrd // If a table has any reocrd
func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) {
t := rType(bean) v := rValue(bean)
t := v.Type()
if t.Kind() != reflect.Struct { if t.Kind() != reflect.Struct {
return false, errors.New("bean should be a struct or struct's point") return false, errors.New("bean should be a struct or struct's point")
} }
engine.autoMapType(t) engine.autoMapType(v)
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
rows, err := session.Count(bean) rows, err := session.Count(bean)
@ -692,11 +710,11 @@ func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) {
// If a table is exist // If a table is exist
func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { func (engine *Engine) IsTableExist(bean interface{}) (bool, error) {
t := rType(bean) v := rValue(bean)
if t.Kind() != reflect.Struct { if v.Type().Kind() != reflect.Struct {
return false, errors.New("bean should be a struct or struct's point") return false, errors.New("bean should be a struct or struct's point")
} }
table := engine.autoMapType(t) table := engine.autoMapType(v)
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
has, err := session.isTableExist(table.Name) has, err := session.isTableExist(table.Name)

View File

@ -37,9 +37,14 @@ func makeArray(elem string, count int) []string {
return res return res
} }
func rValue(bean interface{}) reflect.Value {
return reflect.Indirect(reflect.ValueOf(bean))
}
func rType(bean interface{}) reflect.Type { func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface()) //return reflect.TypeOf(sliceValue.Interface())
return sliceValue.Type()
} }
func structName(v reflect.Type) string { func structName(v reflect.Type) string {

View File

@ -358,12 +358,12 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) {
} }
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := reflect.Indirect(reflect.ValueOf(obj)) dataStruct := rValue(obj)
if dataStruct.Kind() != reflect.Struct { if dataStruct.Kind() != reflect.Struct {
return errors.New("Expected a pointer to a struct") return errors.New("Expected a pointer to a struct")
} }
table := session.Engine.autoMapType(rType(obj)) table := session.Engine.autoMapType(dataStruct)
for key, data := range objMap { for key, data := range objMap {
key = strings.ToLower(key) key = strings.ToLower(key)
@ -1017,12 +1017,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
if session.Statement.RefTable == nil { if session.Statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
table = session.Engine.autoMapType(sliceElementType.Elem()) pv := reflect.New(sliceElementType.Elem())
table = session.Engine.autoMapType(pv.Elem())
} else { } else {
return errors.New("slice type") return errors.New("slice type")
} }
} else if sliceElementType.Kind() == reflect.Struct { } else if sliceElementType.Kind() == reflect.Struct {
table = session.Engine.autoMapType(sliceElementType) pv := reflect.New(sliceElementType)
table = session.Engine.autoMapType(pv.Elem())
} else { } else {
return errors.New("slice type") return errors.New("slice type")
} }
@ -1386,13 +1388,12 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *T
} }
func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error {
dataStruct := rValue(bean)
dataStruct := reflect.Indirect(reflect.ValueOf(bean))
if dataStruct.Kind() != reflect.Struct { if dataStruct.Kind() != reflect.Struct {
return errors.New("Expected a pointer to a struct") return errors.New("Expected a pointer to a struct")
} }
table := session.Engine.autoMapType(rType(bean)) table := session.Engine.autoMapType(dataStruct)
var scanResultContainers []interface{} var scanResultContainers []interface{}
for i := 0; i < fieldsCount; i++ { for i := 0; i < fieldsCount; i++ {
@ -1494,7 +1495,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
fieldValue.Set(vv) fieldValue.Set(vv)
} }
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
table := session.Engine.autoMapType(fieldValue.Type()) table := session.Engine.autoMapType(*fieldValue)
if table != nil { if table != nil {
var x int64 var x int64
if rawValueType.Kind() == reflect.Int64 { if rawValueType.Kind() == reflect.Int64 {
@ -1763,9 +1764,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
} }
bean := sliceValue.Index(0).Interface() bean := sliceValue.Index(0).Interface()
sliceElementType := rType(bean) elementValue := rValue(bean)
//sliceElementType := elementValue.Type()
table := session.Engine.autoMapType(sliceElementType) table := session.Engine.autoMapType(elementValue)
session.Statement.RefTable = table session.Statement.RefTable = table
size := sliceValue.Len() size := sliceValue.Len()
@ -2073,7 +2075,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
v = x v = x
fieldValue.Set(reflect.ValueOf(v)) fieldValue.Set(reflect.ValueOf(v))
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
table := session.Engine.autoMapType(fieldValue.Type()) table := session.Engine.autoMapType(*fieldValue)
if table != nil { if table != nil {
x, err := strconv.ParseInt(string(data), 10, 64) x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil { if err != nil {

View File

@ -113,11 +113,12 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme
// tempororily set table name // tempororily set table name
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
t := rType(tableNameOrBean) v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.String { if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string) statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct { } else if t.Kind() == reflect.Struct {
statement.RefTable = statement.Engine.autoMapType(t) statement.RefTable = statement.Engine.autoMapType(v)
} }
return statement return statement
} }
@ -342,7 +343,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
val = t val = t
} }
} else { } else {
engine.autoMapType(fieldValue.Type()) engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok { if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 { if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)