From 901d7b66706faf08c4bd490815a3bc22b6441a46 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 25 Dec 2013 17:41:01 +0800 Subject: [PATCH] bug fixed --- base_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++++ engine.go | 4 +-- postgres_test.go | 8 ++++-- session.go | 10 +++++--- statement.go | 8 +++--- table.go | 8 +++--- 6 files changed, 89 insertions(+), 16 deletions(-) diff --git a/base_test.go b/base_test.go index d4c15f15..c940ab3a 100644 --- a/base_test.go +++ b/base_test.go @@ -3678,6 +3678,71 @@ func testCompositeKey(engine *Engine, t *testing.T) { } } +type User struct { + UserId string `xorm:"varchar(19) not null pk"` + NickName string `xorm:"varchar(19) not null"` + GameId uint32 `xorm:"integer pk"` + Score int32 `xorm:"integer"` +} + +func testCompositeKey2(engine *Engine, t *testing.T) { + + err := engine.DropTables(&User{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&User{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&User{"11", "nick", 22, 5}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("failed to insert User{11, 22}")) + } + + cnt, err = engine.Insert(&User{"11", "nick", 22, 6}) + if err == nil || cnt == 1 { + t.Error(errors.New("inserted User{11, 22}")) + } + + var user User + has, err := engine.Id(PK{"11", 22}).Get(&user) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get User{11, 22}")) + } + + // test passing PK ptr, this test seem failed withCache + has, err = engine.Id(&PK{"11", 22}).Get(&user) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get User{11, 22}")) + } + + user = User{NickName: "test1"} + cnt, err = engine.Id(PK{"11", 22}).Update(&user) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't update User{11, 22}")) + } + + cnt, err = engine.Id(PK{"11", 22}).Delete(&User{}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't delete CompositeKey{11, 22}")) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -3796,6 +3861,8 @@ func testAll3(engine *Engine, t *testing.T) { testNullValue(engine, t) fmt.Println("-------------- testCompositeKey --------------") testCompositeKey(engine, t) + fmt.Println("-------------- testCompositeKey2 --------------") + testCompositeKey2(engine, t) fmt.Println("-------------- testStringPK --------------") testStringPK(engine, t) } diff --git a/engine.go b/engine.go index ee31e47c..31fdd62b 100644 --- a/engine.go +++ b/engine.go @@ -472,7 +472,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { parentTable := engine.mapType(fieldType) for name, col := range parentTable.Columns { col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) - table.Columns[name] = col + table.Columns[strings.ToLower(name)] = col table.ColumnsSeq = append(table.ColumnsSeq, name) } @@ -603,7 +603,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } if idFieldColName != "" && len(table.PrimaryKeys) == 0 { - col := table.Columns[idFieldColName] + col := table.Columns[strings.ToLower(idFieldColName)] col.IsPrimaryKey = true col.IsAutoIncrement = true col.Nullable = false diff --git a/postgres_test.go b/postgres_test.go index 907b4e15..9657cffd 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -7,12 +7,16 @@ import ( _ "github.com/lib/pq" ) +//var connStr string = "dbname=xorm_test user=lunny password=1234 sslmode=disable" + +var connStr string = "dbname=xorm_test sslmode=disable" + func newPostgresEngine() (*Engine, error) { - return NewEngine("postgres", "dbname=xorm_test user=lunny password=1234 sslmode=disable") + return NewEngine("postgres", connStr) } func newPostgresDriverDB() (*sql.DB, error) { - return sql.Open("postgres", "dbname=xorm_test user=lunny password=1234 sslmode=disable") + return sql.Open("postgres", connStr) } func TestPostgres(t *testing.T) { diff --git a/session.go b/session.go index 03f1c637..bdaa8658 100644 --- a/session.go +++ b/session.go @@ -360,11 +360,13 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b for key, data := range objMap { key = strings.ToLower(key) - if _, ok := table.Columns[key]; !ok { + var col *Column + var ok bool + if col, ok = table.Columns[key]; !ok { session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) continue } - col := table.Columns[key] + fieldName := col.FieldName fieldPath := strings.Split(fieldName, ".") var fieldValue reflect.Value @@ -1197,7 +1199,7 @@ func (session *Session) addColumn(colName string) error { defer session.Close() } //fmt.Println(session.Statement.RefTable) - col := session.Statement.RefTable.Columns[colName] + col := session.Statement.RefTable.Columns[strings.ToLower(colName)] sql, args := session.Statement.genAddColumnStr(col) _, err = session.exec(sql, args...) return err @@ -2470,7 +2472,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { return ErrCacheFailed } - if col, ok := table.Columns[colName]; ok { + if col, ok := table.Columns[strings.ToLower(colName)]; ok { fieldValue := col.ValueOf(bean) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) if col.IsVersion && session.Statement.checkVersion { diff --git a/statement.go b/statement.go index 18834c1a..f87c0c5d 100644 --- a/statement.go +++ b/statement.go @@ -610,7 +610,7 @@ func (statement *Statement) genCreateTableSQL() string { pkList := []string{} for _, colName := range statement.RefTable.ColumnsSeq { - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey { pkList = append(pkList, col.Name) } @@ -618,7 +618,7 @@ func (statement *Statement) genCreateTableSQL() string { statement.Engine.LogDebug("len:", len(pkList)) for _, colName := range statement.RefTable.ColumnsSeq { - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(statement.Engine.dialect) } else { @@ -823,7 +823,7 @@ func (statement *Statement) processIdParam() { for _, elem := range *(statement.IdParam) { for ; i < colCnt; i++ { colName := statement.RefTable.ColumnsSeq[i] - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey { statement.And(fmt.Sprintf("%v=?", col.Name), elem) i++ @@ -837,7 +837,7 @@ func (statement *Statement) processIdParam() { // false update/delete for ; i < colCnt; i++ { colName := statement.RefTable.ColumnsSeq[i] - col := statement.RefTable.Columns[colName] + col := statement.RefTable.Columns[strings.ToLower(colName)] if col.IsPrimaryKey { statement.And(fmt.Sprintf("%v=?", col.Name), "") } diff --git a/table.go b/table.go index 1eee795f..84ba6e53 100644 --- a/table.go +++ b/table.go @@ -366,23 +366,23 @@ func NewTable(name string, t reflect.Type) *Table { func (table *Table) PKColumns() []*Column { columns := make([]*Column, 0) for _, name := range table.PrimaryKeys { - columns = append(columns, table.Columns[name]) + columns = append(columns, table.Columns[strings.ToLower(name)]) } return columns } func (table *Table) AutoIncrColumn() *Column { - return table.Columns[table.AutoIncrement] + return table.Columns[strings.ToLower(table.AutoIncrement)] } func (table *Table) VersionColumn() *Column { - return table.Columns[table.Version] + return table.Columns[strings.ToLower(table.Version)] } // add a column to table func (table *Table) AddColumn(col *Column) { table.ColumnsSeq = append(table.ColumnsSeq, col.Name) - table.Columns[col.Name] = col + table.Columns[strings.ToLower(col.Name)] = col if col.IsPrimaryKey { table.PrimaryKeys = append(table.PrimaryKeys, col.Name) }