From 4ebee70f92ec8b7b64b086c5a744dad1ec2cab73 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 8 May 2013 22:50:19 +0800 Subject: [PATCH] use relfect.Type as key of tables map --- README_CN.md | 4 ++-- engine.go | 40 ++++++++++++++++++++++++---------------- session.go | 50 +++++++------------------------------------------- xorm.go | 3 ++- 4 files changed, 35 insertions(+), 62 deletions(-) diff --git a/README_CN.md b/README_CN.md index 1b9ab420..7fc5a164 100644 --- a/README_CN.md +++ b/README_CN.md @@ -185,14 +185,14 @@ SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) } ##Mapping Rules -1.Struct 和 Struct 的field名字应该为Pascal式命名,默认的映射规则将转换成用下划线连接的命名规则,这个映射是自动进行的,当然,你可以通过修改Engine或者Session的成员IMapper来改变它。 +1.Struct 和 Struct 的field名字应该为Pascal式命名,默认的映射规则将转换成用下划线连接的命名规则,这个映射是自动进行的,当然,你可以通过修改Engine的成员Mapper来改变它。 例如: 结构体的名字UserInfo将会自动对应数据库中的名为user_info的表。 UserInfo中的成员UserName将会自动对应名为user_name的字段。 -2.当然你也可以改变这个规则,这有两种方法。一是实现你自己的IMapper,你可以在mapper.go中查看到这个借口。然后设置到 engine.Mapper,这将影响所有的Session,或者你可以设置到某一个session,那么只会影响到这个session对应的操作。 +2.当然你也可以改变这个规则,这有两种方法。一是实现你自己的IMapper,你可以在mapper.go中查看到这个接口。然后设置到 engine.Mapper。 另外一种方法就通过Field Tag来进行改变,关于Field Tag请参考Go的语言文档,如下列出了Tag中可用的关键字及其对应的意义: diff --git a/engine.go b/engine.go index e22dc9df..84cff18a 100644 --- a/engine.go +++ b/engine.go @@ -26,7 +26,7 @@ type Engine struct { DBName string Charset string Others string - Tables map[string]Table + Tables map[reflect.Type]Table AutoIncrement string ShowSQL bool QuoteIdentifier string @@ -38,6 +38,13 @@ func Type(bean interface{}) reflect.Type { return reflect.TypeOf(sliceValue.Interface()) } +func StructName(v reflect.Type) string { + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Name() +} + func (e *Engine) OpenDB() (db *sql.DB, err error) { db = nil err = nil @@ -78,7 +85,6 @@ func (engine *Engine) MakeSession() (session Session, err error) { engine.QuoteIdentifier = "`" session = Session{Engine: engine, Db: db} } - session.Mapper = engine.Mapper session.Init() return } @@ -261,10 +267,9 @@ func (engine *Engine) MapType(t reflect.Type) Table { func (engine *Engine) Map(beans ...interface{}) (e error) { for _, bean := range beans { - tableName := engine.Mapper.Obj2Table(StructName(bean)) - if _, ok := engine.Tables[tableName]; !ok { - table := engine.MapOne(bean) - engine.Tables[table.Name] = table + t := Type(bean) + if _, ok := engine.Tables[t]; !ok { + engine.Tables[t] = engine.MapOne(bean) } } return @@ -272,14 +277,19 @@ func (engine *Engine) Map(beans ...interface{}) (e error) { func (engine *Engine) UnMap(beans ...interface{}) (e error) { for _, bean := range beans { - tableName := engine.Mapper.Obj2Table(StructName(bean)) - if _, ok := engine.Tables[tableName]; ok { - delete(engine.Tables, tableName) + t := Type(bean) + if _, ok := engine.Tables[t]; ok { + delete(engine.Tables, t) } } return } +func (engine *Engine) Bean2Table(bean interface{}) *Table { + table := engine.Tables[Type(bean)] + return &table +} + func (e *Engine) DropAll() error { session, err := e.MakeSession() session.Begin() @@ -293,11 +303,10 @@ func (e *Engine) DropAll() error { _, err = session.Exec(sql) if err != nil { session.Rollback() - break + return err } } - session.Commit() - return err + return session.Commit() } func (e *Engine) CreateTables(beans ...interface{}) error { @@ -309,16 +318,15 @@ func (e *Engine) CreateTables(beans ...interface{}) error { } for _, bean := range beans { table := e.MapOne(bean) - e.Tables[table.Name] = table + e.Tables[table.Type] = table sql := e.genCreateSQL(&table) _, err = session.Exec(sql) if err != nil { session.Rollback() - break + return err } } - session.Commit() - return err + return session.Commit() } func (e *Engine) CreateAll() error { diff --git a/session.go b/session.go index c2c9a22a..11852f8d 100644 --- a/session.go +++ b/session.go @@ -10,36 +10,11 @@ import ( "time" ) -func getTypeName(obj interface{}) (typestr string) { - typ := reflect.TypeOf(obj) - typestr = typ.String() - - lastDotIndex := strings.LastIndex(typestr, ".") - if lastDotIndex != -1 { - typestr = typestr[lastDotIndex+1:] - } - - return -} - -func StructName(s interface{}) string { - v := reflect.TypeOf(s) - return Type2StructName(v) -} - -func Type2StructName(v reflect.Type) string { - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - return v.Name() -} - type Session struct { Db *sql.DB Engine *Engine Tx *sql.Tx Statement Statement - Mapper IMapper IsAutoCommit bool } @@ -107,23 +82,13 @@ func (session *Session) Commit() error { return session.Tx.Commit() } -func (session *Session) TableName(bean interface{}) string { - return session.Mapper.Obj2Table(StructName(bean)) -} - -func (session *Session) Bean2Table(bean interface{}) *Table { - tablName := session.TableName(bean) - table := session.Engine.Tables[tablName] - return &table -} - func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { dataStruct := reflect.Indirect(reflect.ValueOf(obj)) if dataStruct.Kind() != reflect.Struct { return errors.New("expected a pointer to a struct") } - table := session.Bean2Table(obj) + table := session.Engine.Bean2Table(obj) for key, data := range objMap { structField := dataStruct.FieldByName(table.Columns[key].FieldName) @@ -219,7 +184,7 @@ func (session *Session) Get(bean interface{}) error { statement := session.Statement defer session.Statement.Init() statement.Limit(1) - table := session.Bean2Table(bean) + table := session.Engine.Bean2Table(bean) statement.Table = table colNames, args := session.BuildConditions(table, bean) @@ -249,7 +214,7 @@ func (session *Session) Get(bean interface{}) error { func (session *Session) Count(bean interface{}) (int64, error) { statement := session.Statement defer session.Statement.Init() - table := session.Bean2Table(bean) + table := session.Engine.Bean2Table(bean) statement.Table = table colNames, args := session.BuildConditions(table, bean) @@ -280,8 +245,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() - tableName := session.Mapper.Obj2Table(Type2StructName(sliceElementType)) - table := session.Engine.Tables[tableName] + table := session.Engine.Tables[sliceElementType] statement.Table = &table if len(condiBean) > 0 { @@ -389,7 +353,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { } func (session *Session) InsertOne(bean interface{}) (int64, error) { - table := session.Bean2Table(bean) + table := session.Engine.Bean2Table(bean) colNames := make([]string, 0) colPlaces := make([]string, 0) @@ -460,7 +424,7 @@ func (session *Session) BuildConditions(table *Table, bean interface{}) ([]strin } func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { - table := session.Bean2Table(bean) + table := session.Engine.Bean2Table(bean) colNames, args := session.BuildConditions(table, bean) var condiColNames []string var condiArgs []interface{} @@ -508,7 +472,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } func (session *Session) Delete(bean interface{}) (int64, error) { - table := session.Bean2Table(bean) + table := session.Engine.Bean2Table(bean) colNames, args := session.BuildConditions(table, bean) var condition = "" diff --git a/xorm.go b/xorm.go index d79a0d42..1e913a45 100644 --- a/xorm.go +++ b/xorm.go @@ -1,6 +1,7 @@ package xorm import ( + "reflect" "strings" ) @@ -11,7 +12,7 @@ import ( func Create(schema string) Engine { engine := Engine{} engine.Mapper = SnakeMapper{} - engine.Tables = make(map[string]Table) + engine.Tables = make(map[reflect.Type]Table) engine.Statement.Engine = &engine l := strings.Split(schema, "://") if len(l) == 2 {