diff --git a/base_test.go b/base_test.go index 3199b850..f8d9d794 100644 --- a/base_test.go +++ b/base_test.go @@ -3990,6 +3990,28 @@ func testCompositeKey2(engine *Engine, t *testing.T) { } } +type CustomTableName struct { + Id int64 + Name string +} + +func (c *CustomTableName) TableName() string { + return "customtablename" +} + +func testCustomTableName(engine *Engine, t *testing.T) { + c := new(CustomTableName) + err := engine.DropTables(c) + if err != nil { + t.Error(err) + } + + err = engine.CreateTables(c) + if err != nil { + t.Error(err) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -4100,6 +4122,8 @@ func testAll2(engine *Engine, t *testing.T) { testProcessors(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) + fmt.Println("-------------- testCustomTableName --------------") + testCustomTableName(engine, t) } // !nash! the 3rd set of the test is intended for non-cache enabled engine diff --git a/engine.go b/engine.go index 070c718c..2499d494 100644 --- a/engine.go +++ b/engine.go @@ -157,9 +157,9 @@ func (engine *Engine) NoCascade() *Session { // Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { - t := rType(bean) - engine.autoMapType(t) - engine.Tables[t].Cacher = cacher + v := rValue(bean) + engine.autoMapType(v) + engine.Tables[v.Type()].Cacher = cacher } // OpenDB provides a interface to operate database directly. @@ -435,12 +435,13 @@ func (engine *Engine) Having(conditions string) *Session { return session.Having(conditions) } -func (engine *Engine) autoMapType(t reflect.Type) *Table { +func (engine *Engine) autoMapType(v reflect.Value) *Table { + t := v.Type() engine.mutex.RLock() table, ok := engine.Tables[t] engine.mutex.RUnlock() if !ok { - table = engine.mapType(t) + table = engine.mapType(v) engine.mutex.Lock() engine.Tables[t] = table engine.mutex.Unlock() @@ -449,8 +450,8 @@ func (engine *Engine) autoMapType(t reflect.Type) *Table { } func (engine *Engine) autoMap(bean interface{}) *Table { - t := rType(bean) - return engine.autoMapType(t) + v := rValue(bean) + return engine.autoMapType(v) } func (engine *Engine) newTable() *Table { @@ -475,9 +476,24 @@ func addIndex(indexName string, table *Table, col *Column, indexType int) { } } -func (engine *Engine) mapType(t reflect.Type) *Table { +func (engine *Engine) mapType(v reflect.Value) *Table { + t := v.Type() table := engine.newTable() - table.Name = engine.tableMapper.Obj2Table(t.Name()) + method := v.MethodByName("TableName") + if !method.IsValid() { + method = v.Addr().MethodByName("TableName") + } + if method.IsValid() { + params := []reflect.Value{} + results := method.Call(params) + if len(results) == 1 { + table.Name = results[0].Interface().(string) + } + } + + if table.Name == "" { + table.Name = engine.tableMapper.Obj2Table(t.Name()) + } table.Type = t var idFieldColName string @@ -487,7 +503,8 @@ func (engine *Engine) mapType(t reflect.Type) *Table { tag := t.Field(i).Tag ormTagStr := tag.Get(engine.TagIdentifier) var col *Column - fieldType := t.Field(i).Type + fieldValue := v.Field(i) + fieldType := fieldValue.Type() if ormTagStr != "" { col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, @@ -500,7 +517,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } if (strings.ToUpper(tags[0]) == "EXTENDS") && (fieldType.Kind() == reflect.Struct) { - parentTable := engine.mapType(fieldType) + parentTable := engine.mapType(fieldValue) for name, col := range parentTable.Columns { col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) table.Columns[strings.ToLower(name)] = col @@ -671,19 +688,20 @@ func (engine *Engine) mapping(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() for _, bean := range beans { - t := rType(bean) - engine.Tables[t] = engine.mapType(t) + v := rValue(bean) + engine.Tables[v.Type()] = engine.mapType(v) } return } // If a table has any reocrd func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { - t := rType(bean) + v := rValue(bean) + t := v.Type() if t.Kind() != reflect.Struct { return false, errors.New("bean should be a struct or struct's point") } - engine.autoMapType(t) + engine.autoMapType(v) session := engine.NewSession() defer session.Close() rows, err := session.Count(bean) @@ -692,11 +710,11 @@ func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { // If a table is exist func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { - t := rType(bean) - if t.Kind() != reflect.Struct { + v := rValue(bean) + if v.Type().Kind() != reflect.Struct { return false, errors.New("bean should be a struct or struct's point") } - table := engine.autoMapType(t) + table := engine.autoMapType(v) session := engine.NewSession() defer session.Close() has, err := session.isTableExist(table.Name) diff --git a/helpers.go b/helpers.go index 96f118f2..25b6ddc8 100644 --- a/helpers.go +++ b/helpers.go @@ -37,9 +37,14 @@ func makeArray(elem string, count int) []string { return res } +func rValue(bean interface{}) reflect.Value { + return reflect.Indirect(reflect.ValueOf(bean)) +} + func rType(bean interface{}) reflect.Type { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - return reflect.TypeOf(sliceValue.Interface()) + //return reflect.TypeOf(sliceValue.Interface()) + return sliceValue.Type() } func structName(v reflect.Type) string { diff --git a/session.go b/session.go index bc718b9b..b8035acc 100644 --- a/session.go +++ b/session.go @@ -358,12 +358,12 @@ func cleanupProcessorsClosures(slices *[]func(interface{})) { } func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { - dataStruct := reflect.Indirect(reflect.ValueOf(obj)) + dataStruct := rValue(obj) if dataStruct.Kind() != reflect.Struct { return errors.New("Expected a pointer to a struct") } - table := session.Engine.autoMapType(rType(obj)) + table := session.Engine.autoMapType(dataStruct) for key, data := range objMap { key = strings.ToLower(key) @@ -1017,12 +1017,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if session.Statement.RefTable == nil { if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Elem().Kind() == reflect.Struct { - table = session.Engine.autoMapType(sliceElementType.Elem()) + pv := reflect.New(sliceElementType.Elem()) + table = session.Engine.autoMapType(pv.Elem()) } else { return errors.New("slice type") } } else if sliceElementType.Kind() == reflect.Struct { - table = session.Engine.autoMapType(sliceElementType) + pv := reflect.New(sliceElementType) + table = session.Engine.autoMapType(pv.Elem()) } else { return errors.New("slice type") } @@ -1386,13 +1388,12 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *T } func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { - - dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + dataStruct := rValue(bean) if dataStruct.Kind() != reflect.Struct { return errors.New("Expected a pointer to a struct") } - table := session.Engine.autoMapType(rType(bean)) + table := session.Engine.autoMapType(dataStruct) var scanResultContainers []interface{} for i := 0; i < fieldsCount; i++ { @@ -1494,7 +1495,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in fieldValue.Set(vv) } } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(fieldValue.Type()) + table := session.Engine.autoMapType(*fieldValue) if table != nil { var x int64 if rawValueType.Kind() == reflect.Int64 { @@ -1763,9 +1764,10 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } bean := sliceValue.Index(0).Interface() - sliceElementType := rType(bean) + elementValue := rValue(bean) + //sliceElementType := elementValue.Type() - table := session.Engine.autoMapType(sliceElementType) + table := session.Engine.autoMapType(elementValue) session.Statement.RefTable = table size := sliceValue.Len() @@ -2073,7 +2075,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data v = x fieldValue.Set(reflect.ValueOf(v)) } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(fieldValue.Type()) + table := session.Engine.autoMapType(*fieldValue) if table != nil { x, err := strconv.ParseInt(string(data), 10, 64) if err != nil { diff --git a/statement.go b/statement.go index 805f5955..773dd378 100644 --- a/statement.go +++ b/statement.go @@ -113,11 +113,12 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme // tempororily set table name func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - t := rType(tableNameOrBean) + v := rValue(tableNameOrBean) + t := v.Type() if t.Kind() == reflect.String { statement.AltTableName = tableNameOrBean.(string) } else if t.Kind() == reflect.Struct { - statement.RefTable = statement.Engine.autoMapType(t) + statement.RefTable = statement.Engine.autoMapType(v) } return statement } @@ -342,7 +343,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, val = t } } else { - engine.autoMapType(fieldValue.Type()) + engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { if len(table.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)