use relfect.Type as key of tables map

This commit is contained in:
Lunny Xiao 2013-05-08 22:50:19 +08:00
parent 5870dbaab0
commit 4ebee70f92
4 changed files with 35 additions and 62 deletions

View File

@ -185,14 +185,14 @@ SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
} }
##Mapping Rules ##Mapping Rules
1.Struct 和 Struct 的field名字应该为Pascal式命名默认的映射规则将转换成用下划线连接的命名规则这个映射是自动进行的当然你可以通过修改Engine或者Session的成员IMapper来改变它。 1.Struct 和 Struct 的field名字应该为Pascal式命名默认的映射规则将转换成用下划线连接的命名规则这个映射是自动进行的当然你可以通过修改Engine的成员Mapper来改变它。
例如: 例如:
结构体的名字UserInfo将会自动对应数据库中的名为user_info的表。 结构体的名字UserInfo将会自动对应数据库中的名为user_info的表。
UserInfo中的成员UserName将会自动对应名为user_name的字段。 UserInfo中的成员UserName将会自动对应名为user_name的字段。
2.当然你也可以改变这个规则这有两种方法。一是实现你自己的IMapper你可以在mapper.go中查看到这个借口。然后设置到 engine.Mapper这将影响所有的Session或者你可以设置到某一个session那么只会影响到这个session对应的操作 2.当然你也可以改变这个规则这有两种方法。一是实现你自己的IMapper你可以在mapper.go中查看到这个接口。然后设置到 engine.Mapper
另外一种方法就通过Field Tag来进行改变关于Field Tag请参考Go的语言文档如下列出了Tag中可用的关键字及其对应的意义 另外一种方法就通过Field Tag来进行改变关于Field Tag请参考Go的语言文档如下列出了Tag中可用的关键字及其对应的意义

View File

@ -26,7 +26,7 @@ type Engine struct {
DBName string DBName string
Charset string Charset string
Others string Others string
Tables map[string]Table Tables map[reflect.Type]Table
AutoIncrement string AutoIncrement string
ShowSQL bool ShowSQL bool
QuoteIdentifier string QuoteIdentifier string
@ -38,6 +38,13 @@ func Type(bean interface{}) reflect.Type {
return reflect.TypeOf(sliceValue.Interface()) return reflect.TypeOf(sliceValue.Interface())
} }
func StructName(v reflect.Type) string {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
return v.Name()
}
func (e *Engine) OpenDB() (db *sql.DB, err error) { func (e *Engine) OpenDB() (db *sql.DB, err error) {
db = nil db = nil
err = nil err = nil
@ -78,7 +85,6 @@ func (engine *Engine) MakeSession() (session Session, err error) {
engine.QuoteIdentifier = "`" engine.QuoteIdentifier = "`"
session = Session{Engine: engine, Db: db} session = Session{Engine: engine, Db: db}
} }
session.Mapper = engine.Mapper
session.Init() session.Init()
return return
} }
@ -261,10 +267,9 @@ func (engine *Engine) MapType(t reflect.Type) Table {
func (engine *Engine) Map(beans ...interface{}) (e error) { func (engine *Engine) Map(beans ...interface{}) (e error) {
for _, bean := range beans { for _, bean := range beans {
tableName := engine.Mapper.Obj2Table(StructName(bean)) t := Type(bean)
if _, ok := engine.Tables[tableName]; !ok { if _, ok := engine.Tables[t]; !ok {
table := engine.MapOne(bean) engine.Tables[t] = engine.MapOne(bean)
engine.Tables[table.Name] = table
} }
} }
return return
@ -272,14 +277,19 @@ func (engine *Engine) Map(beans ...interface{}) (e error) {
func (engine *Engine) UnMap(beans ...interface{}) (e error) { func (engine *Engine) UnMap(beans ...interface{}) (e error) {
for _, bean := range beans { for _, bean := range beans {
tableName := engine.Mapper.Obj2Table(StructName(bean)) t := Type(bean)
if _, ok := engine.Tables[tableName]; ok { if _, ok := engine.Tables[t]; ok {
delete(engine.Tables, tableName) delete(engine.Tables, t)
} }
} }
return return
} }
func (engine *Engine) Bean2Table(bean interface{}) *Table {
table := engine.Tables[Type(bean)]
return &table
}
func (e *Engine) DropAll() error { func (e *Engine) DropAll() error {
session, err := e.MakeSession() session, err := e.MakeSession()
session.Begin() session.Begin()
@ -293,11 +303,10 @@ func (e *Engine) DropAll() error {
_, err = session.Exec(sql) _, err = session.Exec(sql)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
break
}
}
session.Commit()
return err return err
}
}
return session.Commit()
} }
func (e *Engine) CreateTables(beans ...interface{}) error { func (e *Engine) CreateTables(beans ...interface{}) error {
@ -309,16 +318,15 @@ func (e *Engine) CreateTables(beans ...interface{}) error {
} }
for _, bean := range beans { for _, bean := range beans {
table := e.MapOne(bean) table := e.MapOne(bean)
e.Tables[table.Name] = table e.Tables[table.Type] = table
sql := e.genCreateSQL(&table) sql := e.genCreateSQL(&table)
_, err = session.Exec(sql) _, err = session.Exec(sql)
if err != nil { if err != nil {
session.Rollback() session.Rollback()
break
}
}
session.Commit()
return err return err
}
}
return session.Commit()
} }
func (e *Engine) CreateAll() error { func (e *Engine) CreateAll() error {

View File

@ -10,36 +10,11 @@ import (
"time" "time"
) )
func getTypeName(obj interface{}) (typestr string) {
typ := reflect.TypeOf(obj)
typestr = typ.String()
lastDotIndex := strings.LastIndex(typestr, ".")
if lastDotIndex != -1 {
typestr = typestr[lastDotIndex+1:]
}
return
}
func StructName(s interface{}) string {
v := reflect.TypeOf(s)
return Type2StructName(v)
}
func Type2StructName(v reflect.Type) string {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
return v.Name()
}
type Session struct { type Session struct {
Db *sql.DB Db *sql.DB
Engine *Engine Engine *Engine
Tx *sql.Tx Tx *sql.Tx
Statement Statement Statement Statement
Mapper IMapper
IsAutoCommit bool IsAutoCommit bool
} }
@ -107,23 +82,13 @@ func (session *Session) Commit() error {
return session.Tx.Commit() return session.Tx.Commit()
} }
func (session *Session) TableName(bean interface{}) string {
return session.Mapper.Obj2Table(StructName(bean))
}
func (session *Session) Bean2Table(bean interface{}) *Table {
tablName := session.TableName(bean)
table := session.Engine.Tables[tablName]
return &table
}
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := reflect.Indirect(reflect.ValueOf(obj)) dataStruct := reflect.Indirect(reflect.ValueOf(obj))
if dataStruct.Kind() != reflect.Struct { if dataStruct.Kind() != reflect.Struct {
return errors.New("expected a pointer to a struct") return errors.New("expected a pointer to a struct")
} }
table := session.Bean2Table(obj) table := session.Engine.Bean2Table(obj)
for key, data := range objMap { for key, data := range objMap {
structField := dataStruct.FieldByName(table.Columns[key].FieldName) structField := dataStruct.FieldByName(table.Columns[key].FieldName)
@ -219,7 +184,7 @@ func (session *Session) Get(bean interface{}) error {
statement := session.Statement statement := session.Statement
defer session.Statement.Init() defer session.Statement.Init()
statement.Limit(1) statement.Limit(1)
table := session.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
statement.Table = table statement.Table = table
colNames, args := session.BuildConditions(table, bean) colNames, args := session.BuildConditions(table, bean)
@ -249,7 +214,7 @@ func (session *Session) Get(bean interface{}) error {
func (session *Session) Count(bean interface{}) (int64, error) { func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.Statement statement := session.Statement
defer session.Statement.Init() defer session.Statement.Init()
table := session.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
statement.Table = table statement.Table = table
colNames, args := session.BuildConditions(table, bean) colNames, args := session.BuildConditions(table, bean)
@ -280,8 +245,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
tableName := session.Mapper.Obj2Table(Type2StructName(sliceElementType)) table := session.Engine.Tables[sliceElementType]
table := session.Engine.Tables[tableName]
statement.Table = &table statement.Table = &table
if len(condiBean) > 0 { if len(condiBean) > 0 {
@ -389,7 +353,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
} }
func (session *Session) InsertOne(bean interface{}) (int64, error) { func (session *Session) InsertOne(bean interface{}) (int64, error) {
table := session.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
colNames := make([]string, 0) colNames := make([]string, 0)
colPlaces := make([]string, 0) colPlaces := make([]string, 0)
@ -460,7 +424,7 @@ func (session *Session) BuildConditions(table *Table, bean interface{}) ([]strin
} }
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
table := session.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
colNames, args := session.BuildConditions(table, bean) colNames, args := session.BuildConditions(table, bean)
var condiColNames []string var condiColNames []string
var condiArgs []interface{} var condiArgs []interface{}
@ -508,7 +472,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
func (session *Session) Delete(bean interface{}) (int64, error) { func (session *Session) Delete(bean interface{}) (int64, error) {
table := session.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
colNames, args := session.BuildConditions(table, bean) colNames, args := session.BuildConditions(table, bean)
var condition = "" var condition = ""

View File

@ -1,6 +1,7 @@
package xorm package xorm
import ( import (
"reflect"
"strings" "strings"
) )
@ -11,7 +12,7 @@ import (
func Create(schema string) Engine { func Create(schema string) Engine {
engine := Engine{} engine := Engine{}
engine.Mapper = SnakeMapper{} engine.Mapper = SnakeMapper{}
engine.Tables = make(map[string]Table) engine.Tables = make(map[reflect.Type]Table)
engine.Statement.Engine = &engine engine.Statement.Engine = &engine
l := strings.Split(schema, "://") l := strings.Split(schema, "://")
if len(l) == 2 { if len(l) == 2 {