diff --git a/.gitignore b/.gitignore index 4b502a3c..cd5e8f73 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ vendor *.log .vendor +temp_test.go diff --git a/base_test.go b/base_test.go index 4067185f..6816fc70 100644 --- a/base_test.go +++ b/base_test.go @@ -1,11 +1,11 @@ package xorm import ( - "errors" - "fmt" - "strings" - "testing" - "time" + "errors" + "fmt" + "strings" + "testing" + "time" ) /* @@ -25,1193 +25,1193 @@ CREATE TABLE `userdeatail` ( */ type Userinfo struct { - Uid int64 `xorm:"id pk not null autoincr"` - Username string `xorm:"unique"` - Departname string - Alias string `xorm:"-"` - Created time.Time - Detail Userdetail `xorm:"detail_id int(11)"` - Height float64 - Avatar []byte - IsMan bool + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool } type Userdetail struct { - Id int64 - Intro string `xorm:"text"` - Profile string `xorm:"varchar(2000)"` + Id int64 + Intro string `xorm:"text"` + Profile string `xorm:"varchar(2000)"` } func directCreateTable(engine *Engine, t *testing.T) { - err := engine.DropTables(&Userinfo{}, &Userdetail{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&Userinfo{}, &Userdetail{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.Sync(&Userinfo{}, &Userdetail{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.Sync(&Userinfo{}, &Userdetail{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.DropTables(&Userinfo{}, &Userdetail{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.DropTables(&Userinfo{}, &Userdetail{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&Userinfo{}, &Userdetail{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&Userinfo{}, &Userdetail{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateIndexes(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateIndexes(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateIndexes(&Userdetail{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateIndexes(&Userdetail{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateUniques(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateUniques(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateUniques(&Userdetail{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateUniques(&Userdetail{}) + if err != nil { + t.Error(err) + panic(err) + } } func insert(engine *Engine, t *testing.T) { - user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), - Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} - cnt, err := engine.Insert(&user) - fmt.Println(user.Uid) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } - if user.Uid <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } + user := Userinfo{0, "xiaolunwen", "dev", "lunny", time.Now(), + Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} + cnt, err := engine.Insert(&user) + fmt.Println(user.Uid) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if user.Uid <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } - user.Uid = 0 - cnt, err = engine.Insert(&user) - if err == nil { - err = errors.New("insert failed but no return error") - t.Error(err) - panic(err) - } - if cnt != 0 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + user.Uid = 0 + cnt, err = engine.Insert(&user) + if err == nil { + err = errors.New("insert failed but no return error") + t.Error(err) + panic(err) + } + if cnt != 0 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } } func testQuery(engine *Engine, t *testing.T) { - sql := "select * from userinfo" - results, err := engine.Query(sql) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(results) + sql := "select * from userinfo" + results, err := engine.Query(sql) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(results) } func exec(engine *Engine, t *testing.T) { - sql := "update userinfo set username=? where id=?" - res, err := engine.Exec(sql, "xiaolun", 1) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(res) + sql := "update userinfo set username=? where id=?" + res, err := engine.Exec(sql, "xiaolun", 1) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(res) } func querySameMapper(engine *Engine, t *testing.T) { - sql := "select * from `Userinfo`" - results, err := engine.Query(sql) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(results) + sql := "select * from `Userinfo`" + results, err := engine.Query(sql) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(results) } func execSameMapper(engine *Engine, t *testing.T) { - sql := "update `Userinfo` set `Username`=? where (id)=?" - res, err := engine.Exec(sql, "xiaolun", 1) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(res) + sql := "update `Userinfo` set `Username`=? where (id)=?" + res, err := engine.Exec(sql, "xiaolun", 1) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(res) } func insertAutoIncr(engine *Engine, t *testing.T) { - // auto increment insert - user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), - Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} - cnt, err := engine.Insert(&user) - fmt.Println(user.Uid) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } - if user.Uid <= 0 { - t.Error(errors.New("not return id error")) - } + // auto increment insert + user := Userinfo{Username: "xiaolunwen2", Departname: "dev", Alias: "lunny", Created: time.Now(), + Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} + cnt, err := engine.Insert(&user) + fmt.Println(user.Uid) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if user.Uid <= 0 { + t.Error(errors.New("not return id error")) + } } func insertMulti(engine *Engine, t *testing.T) { - //engine.InsertMany = true - users := []Userinfo{ - {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - {Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - } - cnt, err := engine.Insert(&users) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != int64(len(users)) { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + //engine.InsertMany = true + users := []Userinfo{ + {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + cnt, err := engine.Insert(&users) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != int64(len(users)) { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } - users2 := []*Userinfo{ - &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - &Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, - &Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, - } + users2 := []*Userinfo{ + &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + &Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } - cnt, err = engine.Insert(&users2) - if err != nil { - t.Error(err) - panic(err) - } + cnt, err = engine.Insert(&users2) + if err != nil { + t.Error(err) + panic(err) + } - if cnt != int64(len(users2)) { - err = errors.New(fmt.Sprintf("insert not returned %v", len(users2))) - t.Error(err) - panic(err) - return - } + if cnt != int64(len(users2)) { + err = errors.New(fmt.Sprintf("insert not returned %v", len(users2))) + t.Error(err) + panic(err) + return + } } func insertTwoTable(engine *Engine, t *testing.T) { - userdetail := Userdetail{Id: 1, Intro: "I'm a very beautiful women.", Profile: "sfsaf"} - userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} + userdetail := Userdetail{Id: 1, Intro: "I'm a very beautiful women.", Profile: "sfsaf"} + userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} - cnt, err := engine.Insert(&userinfo, &userdetail) - if err != nil { - t.Error(err) - panic(err) - } + cnt, err := engine.Insert(&userinfo, &userdetail) + if err != nil { + t.Error(err) + panic(err) + } - if userinfo.Uid <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } + if userinfo.Uid <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } - if userdetail.Id <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } + if userdetail.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } - if cnt != 2 { - err = errors.New("insert not returned 2") - t.Error(err) - panic(err) - return - } + if cnt != 2 { + err = errors.New("insert not returned 2") + t.Error(err) + panic(err) + return + } } type Condi map[string]interface{} func update(engine *Engine, t *testing.T) { - // update by id - user := Userinfo{Username: "xxx", Height: 1.2} - cnt, err := engine.Id(1).Update(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + // update by id + user := Userinfo{Username: "xxx", Height: 1.2} + cnt, err := engine.Id(1).Update(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } - condi := Condi{"username": "zzz", "height": 0.0, "departname": ""} - cnt, err = engine.Table(&user).Id(1).Update(&condi) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + condi := Condi{"username": "zzz", "height": 0.0, "departname": ""} + cnt, err = engine.Table(&user).Id(1).Update(&condi) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } - cnt, err = engine.Update(&Userinfo{Username: "yyy"}, &user) - if err != nil { - t.Error(err) - panic(err) - } - total, err := engine.Count(&user) - if err != nil { - t.Error(err) - panic(err) - } + cnt, err = engine.Update(&Userinfo{Username: "yyy"}, &user) + if err != nil { + t.Error(err) + panic(err) + } + total, err := engine.Count(&user) + if err != nil { + t.Error(err) + panic(err) + } - if cnt != total { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + if cnt != total { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } } func updateSameMapper(engine *Engine, t *testing.T) { - // update by id - user := Userinfo{Username: "xxx", Height: 1.2} - cnt, err := engine.Id(1).Update(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + // update by id + user := Userinfo{Username: "xxx", Height: 1.2} + cnt, err := engine.Id(1).Update(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } - condi := Condi{"Username": "zzz", "Height": 0.0, "Departname": ""} - cnt, err = engine.Table(&user).Id(1).Update(&condi) - if err != nil { - t.Error(err) - panic(err) - } + condi := Condi{"Username": "zzz", "Height": 0.0, "Departname": ""} + cnt, err = engine.Table(&user).Id(1).Update(&condi) + if err != nil { + t.Error(err) + panic(err) + } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } - cnt, err = engine.Update(&Userinfo{Username: "yyy"}, &user) - if err != nil { - t.Error(err) - panic(err) - } + cnt, err = engine.Update(&Userinfo{Username: "yyy"}, &user) + if err != nil { + t.Error(err) + panic(err) + } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } } func testdelete(engine *Engine, t *testing.T) { - user := Userinfo{Uid: 1} - cnt, err := engine.Delete(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("delete not returned 1") - t.Error(err) - panic(err) - return - } + user := Userinfo{Uid: 1} + cnt, err := engine.Delete(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("delete not returned 1") + t.Error(err) + panic(err) + return + } - user.Uid = 0 - has, err := engine.Id(3).Get(&user) - if err != nil { - t.Error(err) - panic(err) - } + user.Uid = 0 + has, err := engine.Id(3).Get(&user) + if err != nil { + t.Error(err) + panic(err) + } - if has { - //var tt time.Time - //user.Created = tt - cnt, err := engine.UseBool().Delete(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - t.Error(errors.New("delete failed")) - panic(err) - } - } + if has { + //var tt time.Time + //user.Created = tt + cnt, err := engine.UseBool().Delete(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + t.Error(errors.New("delete failed")) + panic(err) + } + } } type NoIdUser struct { - User string `xorm:"unique"` - Remain int64 - Total int64 + User string `xorm:"unique"` + Remain int64 + Total int64 } func get(engine *Engine, t *testing.T) { - user := Userinfo{Uid: 2} + user := Userinfo{Uid: 2} - has, err := engine.Get(&user) - if err != nil { - t.Error(err) - panic(err) - } - if has { - fmt.Println(user) - } else { - fmt.Println("no record id is 2") - } + has, err := engine.Get(&user) + if err != nil { + t.Error(err) + panic(err) + } + if has { + fmt.Println(user) + } else { + fmt.Println("no record id is 2") + } - err = engine.Sync(&NoIdUser{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.Sync(&NoIdUser{}) + if err != nil { + t.Error(err) + panic(err) + } - _, err = engine.Where("`user` = ?", "xlw").Delete(&NoIdUser{}) - if err != nil { - t.Error(err) - panic(err) - } + _, err = engine.Where("`user` = ?", "xlw").Delete(&NoIdUser{}) + if err != nil { + t.Error(err) + panic(err) + } - cnt, err := engine.Insert(&NoIdUser{"xlw", 20, 100}) - if err != nil { - t.Error(err) - panic(err) - } + cnt, err := engine.Insert(&NoIdUser{"xlw", 20, 100}) + if err != nil { + t.Error(err) + panic(err) + } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + } - noIdUser := new(NoIdUser) - has, err = engine.Where("`user` = ?", "xlw").Get(noIdUser) - if err != nil { - t.Error(err) - panic(err) - } + noIdUser := new(NoIdUser) + has, err = engine.Where("`user` = ?", "xlw").Get(noIdUser) + if err != nil { + t.Error(err) + panic(err) + } - if !has { - err = errors.New("get not returned 1") - t.Error(err) - panic(err) - } - fmt.Println(noIdUser) + if !has { + err = errors.New("get not returned 1") + t.Error(err) + panic(err) + } + fmt.Println(noIdUser) } func cascadeGet(engine *Engine, t *testing.T) { - user := Userinfo{Uid: 11} + user := Userinfo{Uid: 11} - has, err := engine.Get(&user) - if err != nil { - t.Error(err) - panic(err) - } - if has { - fmt.Println(user) - } else { - fmt.Println("no record id is 2") - } + has, err := engine.Get(&user) + if err != nil { + t.Error(err) + panic(err) + } + if has { + fmt.Println(user) + } else { + fmt.Println("no record id is 2") + } } func find(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) + users := make([]Userinfo, 0) - err := engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - fmt.Println(user) - } + err := engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + fmt.Println(user) + } - users2 := make([]Userinfo, 0) - err = engine.Sql("select * from userinfo").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + users2 := make([]Userinfo, 0) + err = engine.Sql("select * from userinfo").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } } func find2(engine *Engine, t *testing.T) { - users := make([]*Userinfo, 0) + users := make([]*Userinfo, 0) - err := engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - fmt.Println(user) - } + err := engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + fmt.Println(user) + } } func findMap(engine *Engine, t *testing.T) { - users := make(map[int64]Userinfo) + users := make(map[int64]Userinfo) - err := engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - fmt.Println(user) - } + err := engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + fmt.Println(user) + } } func findMap2(engine *Engine, t *testing.T) { - users := make(map[int64]*Userinfo) + users := make(map[int64]*Userinfo) - err := engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for id, user := range users { - fmt.Println(id, user) - } + err := engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for id, user := range users { + fmt.Println(id, user) + } } func count(engine *Engine, t *testing.T) { - user := Userinfo{Departname: "dev"} - total, err := engine.Count(&user) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Printf("Total %d records!!!\n", total) + user := Userinfo{Departname: "dev"} + total, err := engine.Count(&user) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Printf("Total %d records!!!\n", total) } func where(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.Where("(id) > ?", 2).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.Where("(id) > ?", 2).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) - err = engine.Where("(id) > ?", 2).And("(id) < ?", 10).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + err = engine.Where("(id) > ?", 2).And("(id) < ?", 10).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) } func in(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.In("(id)", 1, 2, 3).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.In("(id)", 1, 2, 3).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) - ids := []interface{}{1, 2, 3} - err = engine.Where("(id) > ?", 2).In("(id)", ids...).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + ids := []interface{}{1, 2, 3} + err = engine.Where("(id) > ?", 2).In("(id)", ids...).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) - err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + err = engine.In("(id)", 1).In("(id)", 2).In("departname", "dev").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) - cnt, err := engine.In("(id)", 4).Update(&Userinfo{Departname: "dev-"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update records not 1") - t.Error(err) - panic(err) - } + cnt, err := engine.In("(id)", 4).Update(&Userinfo{Departname: "dev-"}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update records not 1") + t.Error(err) + panic(err) + } - user := new(Userinfo) - has, err := engine.Id(4).Get(user) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get record not 1") - t.Error(err) - panic(err) - } - if user.Departname != "dev-" { - err = errors.New("update not success") - t.Error(err) - panic(err) - } + user := new(Userinfo) + has, err := engine.Id(4).Get(user) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get record not 1") + t.Error(err) + panic(err) + } + if user.Departname != "dev-" { + err = errors.New("update not success") + t.Error(err) + panic(err) + } - cnt, err = engine.In("(id)", 4).Update(&Userinfo{Departname: "dev"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update records not 1") - t.Error(err) - panic(err) - } + cnt, err = engine.In("(id)", 4).Update(&Userinfo{Departname: "dev"}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update records not 1") + t.Error(err) + panic(err) + } - cnt, err = engine.In("(id)", 5).Delete(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("deleted records not 1") - t.Error(err) - panic(err) - } + cnt, err = engine.In("(id)", 5).Delete(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("deleted records not 1") + t.Error(err) + panic(err) + } } func limit(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.Limit(2, 1).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.Limit(2, 1).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) } func order(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.OrderBy("id desc").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.OrderBy("id desc").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) - users2 := make([]Userinfo, 0) - err = engine.Asc("id", "username").Desc("height").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users2) + users2 := make([]Userinfo, 0) + err = engine.Asc("id", "username").Desc("height").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users2) } func join(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.Join("LEFT", "userdetail", "userinfo.id=userdetail.id").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + users := make([]Userinfo, 0) + err := engine.Join("LEFT", "userdetail", "userinfo.id=userdetail.id").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } } func having(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.GroupBy("username").Having("username='xlw'").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.GroupBy("username").Having("username='xlw'").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) } func orderSameMapper(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.OrderBy("(id) desc").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.OrderBy("(id) desc").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) - users2 := make([]Userinfo, 0) - err = engine.Asc("(id)", "Username").Desc("Height").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users2) + users2 := make([]Userinfo, 0) + err = engine.Asc("(id)", "Username").Desc("Height").Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users2) } func joinSameMapper(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.Join("LEFT", `"Userdetail"`, `"Userinfo"."id"="Userdetail"."Id"`).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + users := make([]Userinfo, 0) + err := engine.Join("LEFT", `"Userdetail"`, `"Userinfo"."id"="Userdetail"."Id"`).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } } func havingSameMapper(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.GroupBy("Username").Having(`"Username"='xlw'`).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(users) + users := make([]Userinfo, 0) + err := engine.GroupBy("Username").Having(`"Username"='xlw'`).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(users) } func transaction(engine *Engine, t *testing.T) { - counter := func() { - total, err := engine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } - fmt.Printf("----now total %v records\n", total) - } + counter := func() { + total, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } - counter() - defer counter() - session := engine.NewSession() - defer session.Close() + counter() + defer counter() + session := engine.NewSession() + defer session.Close() - err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - } - //session.IsAutoRollback = false - user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} - _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + //session.IsAutoRollback = false + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - user2 := Userinfo{Username: "yyy"} - _, err = session.Where("(id) = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - fmt.Println(err) - //t.Error(err) - return - } + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("(id) = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + fmt.Println(err) + //t.Error(err) + return + } - _, err = session.Delete(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + _, err = session.Delete(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } - // panic(err) !nashtsai! should remove this + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } + // panic(err) !nashtsai! should remove this } func combineTransaction(engine *Engine, t *testing.T) { - counter := func() { - total, err := engine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } - fmt.Printf("----now total %v records\n", total) - } + counter := func() { + total, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } - counter() - defer counter() - session := engine.NewSession() - defer session.Close() + counter() + defer counter() + session := engine.NewSession() + defer session.Close() - err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} - _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } - user2 := Userinfo{Username: "zzz"} - _, err = session.Where("id = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + user2 := Userinfo{Username: "zzz"} + _, err = session.Where("id = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - _, err = session.Exec("delete from userinfo where username = ?", user2.Username) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + _, err = session.Exec("delete from userinfo where username = ?", user2.Username) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } } func combineTransactionSameMapper(engine *Engine, t *testing.T) { - counter := func() { - total, err := engine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - } - fmt.Printf("----now total %v records\n", total) - } + counter := func() { + total, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } - counter() - defer counter() - session := engine.NewSession() - defer session.Close() + counter() + defer counter() + session := engine.NewSession() + defer session.Close() - err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - } - //session.IsAutoRollback = false - user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} - _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } - user2 := Userinfo{Username: "zzz"} - _, err = session.Where("(id) = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + //session.IsAutoRollback = false + user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + user2 := Userinfo{Username: "zzz"} + _, err = session.Where("(id) = ?", 0).Update(&user2) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - _, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + _, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } } func table(engine *Engine, t *testing.T) { - err := engine.DropTables("user_user") - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables("user_user") + if err != nil { + t.Error(err) + panic(err) + } - err = engine.Table("user_user").CreateTable(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.Table("user_user").CreateTable(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } } func createMultiTables(engine *Engine, t *testing.T) { - session := engine.NewSession() - defer session.Close() + session := engine.NewSession() + defer session.Close() - user := &Userinfo{} - err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - } - for i := 0; i < 10; i++ { - tableName := fmt.Sprintf("user_%v", i) + user := &Userinfo{} + err := session.Begin() + if err != nil { + t.Error(err) + panic(err) + } + for i := 0; i < 10; i++ { + tableName := fmt.Sprintf("user_%v", i) - err = session.DropTable(tableName) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + err = session.DropTable(tableName) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } - err = session.Table(tableName).CreateTable(user) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } - } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + err = session.Table(tableName).CreateTable(user) + if err != nil { + session.Rollback() + t.Error(err) + panic(err) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } } func tableOp(engine *Engine, t *testing.T) { - user := Userinfo{Username: "tablexiao", Departname: "dev", Alias: "lunny", Created: time.Now()} - tableName := fmt.Sprintf("user_%v", len(user.Username)) - cnt, err := engine.Table(tableName).Insert(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + user := Userinfo{Username: "tablexiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + tableName := fmt.Sprintf("user_%v", len(user.Username)) + cnt, err := engine.Table(tableName).Insert(&user) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } - has, err := engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("Get has return false") - t.Error(err) - panic(err) - return - } + has, err := engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("Get has return false") + t.Error(err) + panic(err) + return + } - users := make([]Userinfo, 0) - err = engine.Table(tableName).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + users := make([]Userinfo, 0) + err = engine.Table(tableName).Find(&users) + if err != nil { + t.Error(err) + panic(err) + } - id := user.Uid - cnt, err = engine.Table(tableName).Id(id).Update(&Userinfo{Username: "tableda"}) - if err != nil { - t.Error(err) - panic(err) - } + id := user.Uid + cnt, err = engine.Table(tableName).Id(id).Update(&Userinfo{Username: "tableda"}) + if err != nil { + t.Error(err) + panic(err) + } - _, err = engine.Table(tableName).Id(id).Delete(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + _, err = engine.Table(tableName).Id(id).Delete(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } } func testCharst(engine *Engine, t *testing.T) { - err := engine.DropTables("user_charset") - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables("user_charset") + if err != nil { + t.Error(err) + panic(err) + } - err = engine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.Charset("utf8").Table("user_charset").CreateTable(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } } func testStoreEngine(engine *Engine, t *testing.T) { - err := engine.DropTables("user_store_engine") - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables("user_store_engine") + if err != nil { + t.Error(err) + panic(err) + } - err = engine.StoreEngine("InnoDB").Table("user_store_engine").CreateTable(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.StoreEngine("InnoDB").Table("user_store_engine").CreateTable(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } } type tempUser struct { - Id int64 - Username string + Id int64 + Username string } func testCols(engine *Engine, t *testing.T) { - users := []Userinfo{} - err := engine.Cols("id, username").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + users := []Userinfo{} + err := engine.Cols("id, username").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } - fmt.Println(users) + fmt.Println(users) - tmpUsers := []tempUser{} - err = engine.NoCache().Table("userinfo").Cols("id, username").Find(&tmpUsers) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(tmpUsers) + tmpUsers := []tempUser{} + err = engine.NoCache().Table("userinfo").Cols("id, username").Find(&tmpUsers) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(tmpUsers) - user := &Userinfo{Uid: 1, Alias: "", Height: 0} - affected, err := engine.Cols("departname, height").Id(1).Update(user) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println("===================", user, affected) + user := &Userinfo{Uid: 1, Alias: "", Height: 0} + affected, err := engine.Cols("departname, height").Id(1).Update(user) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println("===================", user, affected) } func testColsSameMapper(engine *Engine, t *testing.T) { - users := []Userinfo{} - err := engine.Cols("(id), Username").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + users := []Userinfo{} + err := engine.Cols("(id), Username").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } - fmt.Println(users) + fmt.Println(users) - tmpUsers := []tempUser{} - err = engine.Table("Userinfo").Cols("(id), Username").Find(&tmpUsers) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(tmpUsers) + tmpUsers := []tempUser{} + err = engine.Table("Userinfo").Cols("(id), Username").Find(&tmpUsers) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(tmpUsers) - user := &Userinfo{Uid: 1, Alias: "", Height: 0} - affected, err := engine.Cols("Departname, Height").Update(user) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println("===================", user, affected) + user := &Userinfo{Uid: 1, Alias: "", Height: 0} + affected, err := engine.Cols("Departname, Height").Update(user) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println("===================", user, affected) } type tempUser2 struct { - tempUser `xorm:"extends"` - Departname string + tempUser `xorm:"extends"` + Departname string } func testExtends(engine *Engine, t *testing.T) { - err := engine.DropTables(&tempUser2{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&tempUser2{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&tempUser2{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&tempUser2{}) + if err != nil { + t.Error(err) + panic(err) + } - tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} - _, err = engine.Insert(tu) - if err != nil { - t.Error(err) - panic(err) - } + tu := &tempUser2{tempUser{0, "extends"}, "dev depart"} + _, err = engine.Insert(tu) + if err != nil { + t.Error(err) + panic(err) + } - tu2 := &tempUser2{} - _, err = engine.Get(tu2) - if err != nil { - t.Error(err) - panic(err) - } + tu2 := &tempUser2{} + _, err = engine.Get(tu2) + if err != nil { + t.Error(err) + panic(err) + } - tu3 := &tempUser2{tempUser{0, "extends update"}, ""} - _, err = engine.Id(tu2.Id).Update(tu3) - if err != nil { - t.Error(err) - panic(err) - } + tu3 := &tempUser2{tempUser{0, "extends update"}, ""} + _, err = engine.Id(tu2.Id).Update(tu3) + if err != nil { + t.Error(err) + panic(err) + } } type allCols struct { - Bit int `xorm:"BIT"` - TinyInt int8 `xorm:"TINYINT"` - SmallInt int16 `xorm:"SMALLINT"` - MediumInt int32 `xorm:"MEDIUMINT"` - Int int `xorm:"INT"` - Integer int `xorm:"INTEGER"` - BigInt int64 `xorm:"BIGINT"` + Bit int `xorm:"BIT"` + TinyInt int8 `xorm:"TINYINT"` + SmallInt int16 `xorm:"SMALLINT"` + MediumInt int32 `xorm:"MEDIUMINT"` + Int int `xorm:"INT"` + Integer int `xorm:"INTEGER"` + BigInt int64 `xorm:"BIGINT"` - Char string `xorm:"CHAR(12)"` - Varchar string `xorm:"VARCHAR(54)"` - TinyText string `xorm:"TINYTEXT"` - Text string `xorm:"TEXT"` - MediumText string `xorm:"MEDIUMTEXT"` - LongText string `xorm:"LONGTEXT"` - Binary []byte `xorm:"BINARY(23)"` - VarBinary []byte `xorm:"VARBINARY(12)"` + Char string `xorm:"CHAR(12)"` + Varchar string `xorm:"VARCHAR(54)"` + TinyText string `xorm:"TINYTEXT"` + Text string `xorm:"TEXT"` + MediumText string `xorm:"MEDIUMTEXT"` + LongText string `xorm:"LONGTEXT"` + Binary []byte `xorm:"BINARY(23)"` + VarBinary []byte `xorm:"VARBINARY(12)"` - Date time.Time `xorm:"DATE"` - DateTime time.Time `xorm:"DATETIME"` - Time time.Time `xorm:"TIME"` - TimeStamp time.Time `xorm:"TIMESTAMP"` + Date time.Time `xorm:"DATE"` + DateTime time.Time `xorm:"DATETIME"` + Time time.Time `xorm:"TIME"` + TimeStamp time.Time `xorm:"TIMESTAMP"` - Decimal float64 `xorm:"DECIMAL"` - Numeric float64 `xorm:"NUMERIC"` + Decimal float64 `xorm:"DECIMAL"` + Numeric float64 `xorm:"NUMERIC"` - Real float32 `xorm:"REAL"` - Float float32 `xorm:"FLOAT"` - Double float64 `xorm:"DOUBLE"` + Real float32 `xorm:"REAL"` + Float float32 `xorm:"FLOAT"` + Double float64 `xorm:"DOUBLE"` - TinyBlob []byte `xorm:"TINYBLOB"` - Blob []byte `xorm:"BLOB"` - MediumBlob []byte `xorm:"MEDIUMBLOB"` - LongBlob []byte `xorm:"LONGBLOB"` - Bytea []byte `xorm:"BYTEA"` + TinyBlob []byte `xorm:"TINYBLOB"` + Blob []byte `xorm:"BLOB"` + MediumBlob []byte `xorm:"MEDIUMBLOB"` + LongBlob []byte `xorm:"LONGBLOB"` + Bytea []byte `xorm:"BYTEA"` - Bool bool `xorm:"BOOL"` + Bool bool `xorm:"BOOL"` - Serial int `xorm:"SERIAL"` - //BigSerial int64 `xorm:"BIGSERIAL"` + Serial int `xorm:"SERIAL"` + //BigSerial int64 `xorm:"BIGSERIAL"` } func testColTypes(engine *Engine, t *testing.T) { - err := engine.DropTables(&allCols{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&allCols{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&allCols{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&allCols{}) + if err != nil { + t.Error(err) + panic(err) + } - ac := &allCols{ - 1, - 4, - 8, - 16, - 32, - 64, - 128, + ac := &allCols{ + 1, + 4, + 8, + 16, + 32, + 64, + 128, - "123", - "fafdafa", - "fafafafdsafdsafdaf", - "fdsafafdsafdsaf", - "fafdsafdsafdsfadasfsfafd", - "fadfdsafdsafasfdasfds", - []byte("fdafsafdasfdsafsa"), - []byte("fdsafsdafs"), + "123", + "fafdafa", + "fafafafdsafdsafdaf", + "fdsafafdsafdsaf", + "fafdsafdsafdsfadasfsfafd", + "fadfdsafdsafasfdasfds", + []byte("fdafsafdasfdsafsa"), + []byte("fdsafsdafs"), - time.Now(), - time.Now(), - time.Now(), - time.Now(), + time.Now(), + time.Now(), + time.Now(), + time.Now(), - 1.34, - 2.44302346, + 1.34, + 2.44302346, - 1.3344, - 2.59693523, - 3.2342523543, + 1.3344, + 2.59693523, + 3.2342523543, - []byte("fafdasf"), - []byte("fafdfdsafdsafasf"), - []byte("faffadsfdsdasf"), - []byte("faffdasfdsadasf"), - []byte("fafasdfsadffdasf"), + []byte("fafdasf"), + []byte("fafdfdsafdsafasf"), + []byte("faffadsfdsdasf"), + []byte("faffdasfdsadasf"), + []byte("fafasdfsadffdasf"), - true, + true, - 21, - } + 21, + } - cnt, err := engine.Insert(ac) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert return not 1") - t.Error(err) - panic(err) - } - newAc := &allCols{} - has, err := engine.Get(newAc) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("error no ideas") - t.Error(err) - panic(err) - } + cnt, err := engine.Insert(ac) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert return not 1") + t.Error(err) + panic(err) + } + newAc := &allCols{} + has, err := engine.Get(newAc) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("error no ideas") + t.Error(err) + panic(err) + } - // don't use this type as query condition - newAc.Real = 0 - newAc.Float = 0 - newAc.Double = 0 - cnt, err = engine.Delete(newAc) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New(fmt.Sprintf("delete error, deleted counts is %v", cnt)) - t.Error(err) - panic(err) - } + // don't use this type as query condition + newAc.Real = 0 + newAc.Float = 0 + newAc.Double = 0 + cnt, err = engine.Delete(newAc) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New(fmt.Sprintf("delete error, deleted counts is %v", cnt)) + t.Error(err) + panic(err) + } } type MyInt int @@ -1230,2204 +1230,2203 @@ func (s *MyString) ToDB() ([]byte, error) { }*/ type MyStruct struct { - Type MyInt - U MyUInt - F MyFloat - S MyString - IA []MyInt - UA []MyUInt - FA []MyFloat - SA []MyString - NameArray []string - Name string - UIA []uint - UIA8 []uint8 - UIA16 []uint16 - UIA32 []uint32 - UIA64 []uint64 - UI uint - //C64 complex64 - MSS map[string]string + Type MyInt + U MyUInt + F MyFloat + S MyString + IA []MyInt + UA []MyUInt + FA []MyFloat + SA []MyString + NameArray []string + Name string + UIA []uint + UIA8 []uint8 + UIA16 []uint16 + UIA32 []uint32 + UIA64 []uint64 + UI uint + //C64 complex64 + MSS map[string]string } func testCustomType(engine *Engine, t *testing.T) { - err := engine.DropTables(&MyStruct{}) - if err != nil { - t.Error(err) - panic(err) - return - } + err := engine.DropTables(&MyStruct{}) + if err != nil { + t.Error(err) + panic(err) + return + } - err = engine.CreateTables(&MyStruct{}) - i := MyStruct{Name: "Test", Type: MyInt(1)} - i.U = 23 - i.F = 1.34 - i.S = "fafdsafdsaf" - i.UI = 2 - i.IA = []MyInt{1, 3, 5} - i.UIA = []uint{1, 3} - i.UIA16 = []uint16{2} - i.UIA32 = []uint32{4, 5} - i.UIA64 = []uint64{6, 7, 9} - i.UIA8 = []uint8{1, 2, 3, 4} - i.NameArray = []string{"ssss", "fsdf", "lllll, ss"} - i.MSS = map[string]string{"s": "sfds,ss", "x": "lfjljsl"} - cnt, err := engine.Insert(&i) - if err != nil { - t.Error(err) - panic(err) - return - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + err = engine.CreateTables(&MyStruct{}) + i := MyStruct{Name: "Test", Type: MyInt(1)} + i.U = 23 + i.F = 1.34 + i.S = "fafdsafdsaf" + i.UI = 2 + i.IA = []MyInt{1, 3, 5} + i.UIA = []uint{1, 3} + i.UIA16 = []uint16{2} + i.UIA32 = []uint32{4, 5} + i.UIA64 = []uint64{6, 7, 9} + i.UIA8 = []uint8{1, 2, 3, 4} + i.NameArray = []string{"ssss", "fsdf", "lllll, ss"} + i.MSS = map[string]string{"s": "sfds,ss", "x": "lfjljsl"} + cnt, err := engine.Insert(&i) + if err != nil { + t.Error(err) + panic(err) + return + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } - fmt.Println(i) - has, err := engine.Get(&i) - if err != nil { - t.Error(err) - panic(err) - } else if !has { - t.Error(errors.New("should get one record")) - panic(err) - } + fmt.Println(i) + has, err := engine.Get(&i) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("should get one record")) + panic(err) + } - ss := []MyStruct{} - err = engine.Find(&ss) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(ss) + ss := []MyStruct{} + err = engine.Find(&ss) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(ss) - sss := MyStruct{} - has, err = engine.Get(&sss) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(sss) + sss := MyStruct{} + has, err = engine.Get(&sss) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(sss) - if has { - cnt, err := engine.Delete(&sss) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - t.Error(errors.New("delete error")) - panic(err) - } - } + if has { + cnt, err := engine.Delete(&sss) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + t.Error(errors.New("delete error")) + panic(err) + } + } } type UserCU struct { - Id int64 - Name string - Created time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` + Id int64 + Name string + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` } func testCreatedAndUpdated(engine *Engine, t *testing.T) { - u := new(UserCU) - err := engine.DropTables(u) - if err != nil { - t.Error(err) - panic(err) - } + u := new(UserCU) + err := engine.DropTables(u) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(u) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(u) + if err != nil { + t.Error(err) + panic(err) + } - u.Name = "sss" - cnt, err := engine.Insert(u) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + u.Name = "sss" + cnt, err := engine.Insert(u) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } - u.Name = "xxx" - cnt, err = engine.Id(u.Id).Update(u) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + u.Name = "xxx" + cnt, err = engine.Id(u.Id).Update(u) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("update not returned 1") + t.Error(err) + panic(err) + return + } - u.Id = 0 - u.Created = time.Now().Add(-time.Hour * 24 * 365) - u.Updated = u.Created - fmt.Println(u) - cnt, err = engine.NoAutoTime().Insert(u) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + u.Id = 0 + u.Created = time.Now().Add(-time.Hour * 24 * 365) + u.Updated = u.Created + fmt.Println(u) + cnt, err = engine.NoAutoTime().Insert(u) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } } type IndexOrUnique struct { - Id int64 - Index int `xorm:"index"` - Unique int `xorm:"unique"` - Group1 int `xorm:"index(ttt)"` - Group2 int `xorm:"index(ttt)"` - UniGroup1 int `xorm:"unique(lll)"` - UniGroup2 int `xorm:"unique(lll)"` + Id int64 + Index int `xorm:"index"` + Unique int `xorm:"unique"` + Group1 int `xorm:"index(ttt)"` + Group2 int `xorm:"index(ttt)"` + UniGroup1 int `xorm:"unique(lll)"` + UniGroup2 int `xorm:"unique(lll)"` } func testIndexAndUnique(engine *Engine, t *testing.T) { - err := engine.CreateTables(&IndexOrUnique{}) - if err != nil { - t.Error(err) - //panic(err) - } + err := engine.CreateTables(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } - err = engine.DropTables(&IndexOrUnique{}) - if err != nil { - t.Error(err) - //panic(err) - } + err = engine.DropTables(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } - err = engine.CreateTables(&IndexOrUnique{}) - if err != nil { - t.Error(err) - //panic(err) - } + err = engine.CreateTables(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } - err = engine.CreateIndexes(&IndexOrUnique{}) - if err != nil { - t.Error(err) - //panic(err) - } + err = engine.CreateIndexes(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } - err = engine.CreateUniques(&IndexOrUnique{}) - if err != nil { - t.Error(err) - //panic(err) - } + err = engine.CreateUniques(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } } type IntId struct { - Id int - Name string + Id int + Name string } type Int32Id struct { - Id int32 - Name string + Id int32 + Name string } func testIntId(engine *Engine, t *testing.T) { - err := engine.DropTables(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&IntId{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&IntId{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&IntId{}) + if err != nil { + t.Error(err) + panic(err) + } - _, err = engine.Insert(&IntId{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } + _, err = engine.Insert(&IntId{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } } func testInt32Id(engine *Engine, t *testing.T) { - err := engine.DropTables(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&Int32Id{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&Int32Id{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&Int32Id{}) + if err != nil { + t.Error(err) + panic(err) + } - _, err = engine.Insert(&Int32Id{Name: "test"}) - if err != nil { - t.Error(err) - panic(err) - } + _, err = engine.Insert(&Int32Id{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } } func testMetaInfo(engine *Engine, t *testing.T) { - tables, err := engine.DBMetas() - if err != nil { - t.Error(err) - panic(err) - } + tables, err := engine.DBMetas() + if err != nil { + t.Error(err) + panic(err) + } - for _, table := range tables { - fmt.Println(table.Name) - for _, col := range table.Columns { - fmt.Println(col.String(engine.dialect)) - } + for _, table := range tables { + fmt.Println(table.Name) + for _, col := range table.Columns { + fmt.Println(col.String(engine.dialect)) + } - for _, index := range table.Indexes { - fmt.Println(index.Name, index.Type, strings.Join(index.Cols, ",")) - } - } + for _, index := range table.Indexes { + fmt.Println(index.Name, index.Type, strings.Join(index.Cols, ",")) + } + } } func testIterate(engine *Engine, t *testing.T) { - err := engine.Omit("is_man").Iterate(new(Userinfo), func(idx int, bean interface{}) error { - user := bean.(*Userinfo) - fmt.Println(idx, "--", user) - return nil - }) + err := engine.Omit("is_man").Iterate(new(Userinfo), func(idx int, bean interface{}) error { + user := bean.(*Userinfo) + fmt.Println(idx, "--", user) + return nil + }) - if err != nil { - t.Error(err) - panic(err) - } + if err != nil { + t.Error(err) + panic(err) + } } type StrangeName struct { - Id_t int64 `xorm:"pk autoincr"` - Name string + Id_t int64 `xorm:"pk autoincr"` + Name string } func testStrangeName(engine *Engine, t *testing.T) { - err := engine.DropTables(new(StrangeName)) - if err != nil { - t.Error(err) - } + err := engine.DropTables(new(StrangeName)) + if err != nil { + t.Error(err) + } - err = engine.CreateTables(new(StrangeName)) - if err != nil { - t.Error(err) - } + err = engine.CreateTables(new(StrangeName)) + if err != nil { + t.Error(err) + } - _, err = engine.Insert(&StrangeName{Name: "sfsfdsfds"}) - if err != nil { - t.Error(err) - } + _, err = engine.Insert(&StrangeName{Name: "sfsfdsfds"}) + if err != nil { + t.Error(err) + } - beans := make([]StrangeName, 0) - err = engine.Find(&beans) - if err != nil { - t.Error(err) - } + beans := make([]StrangeName, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + } } type Version struct { - Id int64 - Name string - Ver int `xorm:"version"` + Id int64 + Name string + Ver int `xorm:"version"` } func testVersion(engine *Engine, t *testing.T) { - err := engine.DropTables(new(Version)) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(new(Version)) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(new(Version)) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(new(Version)) + if err != nil { + t.Error(err) + panic(err) + } - ver := &Version{Name: "sfsfdsfds"} - _, err = engine.Insert(ver) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(ver) - if ver.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } + ver := &Version{Name: "sfsfdsfds"} + _, err = engine.Insert(ver) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(ver) + if ver.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } - newVer := new(Version) - has, err := engine.Id(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } + newVer := new(Version) + has, err := engine.Id(ver.Id).Get(newVer) + if err != nil { + t.Error(err) + panic(err) + } - if !has { - t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 1 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } + if !has { + t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id))) + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 1 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } - newVer.Name = "-------" - _, err = engine.Id(ver.Id).Update(newVer) - if err != nil { - t.Error(err) - panic(err) - } + newVer.Name = "-------" + _, err = engine.Id(ver.Id).Update(newVer) + if err != nil { + t.Error(err) + panic(err) + } - newVer = new(Version) - has, err = engine.Id(ver.Id).Get(newVer) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println(newVer) - if newVer.Ver != 2 { - err = errors.New("insert error") - t.Error(err) - panic(err) - } + newVer = new(Version) + has, err = engine.Id(ver.Id).Get(newVer) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println(newVer) + if newVer.Ver != 2 { + err = errors.New("insert error") + t.Error(err) + panic(err) + } - /* - newVer.Name = "-------" - _, err = engine.Id(ver.Id).Update(newVer) - if err != nil { - t.Error(err) - return - }*/ + /* + newVer.Name = "-------" + _, err = engine.Id(ver.Id).Update(newVer) + if err != nil { + t.Error(err) + return + }*/ } func testDistinct(engine *Engine, t *testing.T) { - users := make([]Userinfo, 0) - err := engine.Distinct("departname").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - t.Error(err) - panic(errors.New("should be one record")) - } + users := make([]Userinfo, 0) + err := engine.Distinct("departname").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + if len(users) != 1 { + t.Error(err) + panic(errors.New("should be one record")) + } - fmt.Println(users) + fmt.Println(users) - type Depart struct { - Departname string - } + type Depart struct { + Departname string + } - users2 := make([]Depart, 0) - err = engine.Distinct("departname").Table(new(Userinfo)).Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } - if len(users2) != 1 { - t.Error(err) - panic(errors.New("should be one record")) - } - fmt.Println(users2) + users2 := make([]Depart, 0) + err = engine.Distinct("departname").Table(new(Userinfo)).Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + if len(users2) != 1 { + t.Error(err) + panic(errors.New("should be one record")) + } + fmt.Println(users2) } func testUseBool(engine *Engine, t *testing.T) { - cnt1, err := engine.Count(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + cnt1, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } - users := make([]Userinfo, 0) - err = engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - var fNumber int64 - for _, u := range users { - if u.IsMan == false { - fNumber += 1 - } - } + users := make([]Userinfo, 0) + err = engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + var fNumber int64 + for _, u := range users { + if u.IsMan == false { + fNumber += 1 + } + } - cnt2, err := engine.UseBool().Update(&Userinfo{IsMan: true}) - if err != nil { - t.Error(err) - panic(err) - } - if fNumber != cnt2 { - fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2) - /*err = errors.New("Updated number is not corrected.") - t.Error(err) - panic(err)*/ - } + cnt2, err := engine.UseBool().Update(&Userinfo{IsMan: true}) + if err != nil { + t.Error(err) + panic(err) + } + if fNumber != cnt2 { + fmt.Println("cnt1", cnt1, "fNumber", fNumber, "cnt2", cnt2) + /*err = errors.New("Updated number is not corrected.") + t.Error(err) + panic(err)*/ + } - _, err = engine.Update(&Userinfo{IsMan: true}) - if err == nil { - err = errors.New("error condition") - t.Error(err) - panic(err) - } + _, err = engine.Update(&Userinfo{IsMan: true}) + if err == nil { + err = errors.New("error condition") + t.Error(err) + panic(err) + } } func testBool(engine *Engine, t *testing.T) { - _, err := engine.UseBool().Update(&Userinfo{IsMan: true}) - if err != nil { - t.Error(err) - panic(err) - } - users := make([]Userinfo, 0) - err = engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - if !user.IsMan { - err = errors.New("update bool or find bool error") - t.Error(err) - panic(err) - } - } + _, err := engine.UseBool().Update(&Userinfo{IsMan: true}) + if err != nil { + t.Error(err) + panic(err) + } + users := make([]Userinfo, 0) + err = engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + if !user.IsMan { + err = errors.New("update bool or find bool error") + t.Error(err) + panic(err) + } + } - _, err = engine.UseBool().Update(&Userinfo{IsMan: false}) - if err != nil { - t.Error(err) - panic(err) - } - users = make([]Userinfo, 0) - err = engine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - for _, user := range users { - if user.IsMan { - err = errors.New("update bool or find bool error") - t.Error(err) - panic(err) - } - } + _, err = engine.UseBool().Update(&Userinfo{IsMan: false}) + if err != nil { + t.Error(err) + panic(err) + } + users = make([]Userinfo, 0) + err = engine.Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + for _, user := range users { + if user.IsMan { + err = errors.New("update bool or find bool error") + t.Error(err) + panic(err) + } + } } type TTime struct { - Id int64 - T time.Time - Tz time.Time `xorm:"timestampz"` + Id int64 + T time.Time + Tz time.Time `xorm:"timestampz"` } func testTime(engine *Engine, t *testing.T) { - err := engine.Sync(&TTime{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.Sync(&TTime{}) + if err != nil { + t.Error(err) + panic(err) + } - tt := &TTime{} - _, err = engine.Insert(tt) - if err != nil { - t.Error(err) - panic(err) - } + tt := &TTime{} + _, err = engine.Insert(tt) + if err != nil { + t.Error(err) + panic(err) + } - tt2 := &TTime{Id: tt.Id} - has, err := engine.Get(tt2) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("no record error") - t.Error(err) - panic(err) - } + tt2 := &TTime{Id: tt.Id} + has, err := engine.Get(tt2) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("no record error") + t.Error(err) + panic(err) + } - tt3 := &TTime{T: time.Now(), Tz: time.Now()} - _, err = engine.Insert(tt3) - if err != nil { - t.Error(err) - panic(err) - } + tt3 := &TTime{T: time.Now(), Tz: time.Now()} + _, err = engine.Insert(tt3) + if err != nil { + t.Error(err) + panic(err) + } - tt4s := make([]TTime, 0) - err = engine.Find(&tt4s) - if err != nil { - t.Error(err) - panic(err) - } - fmt.Println("=======\n", tt4s, "=======\n") + tt4s := make([]TTime, 0) + err = engine.Find(&tt4s) + if err != nil { + t.Error(err) + panic(err) + } + fmt.Println("=======\n", tt4s, "=======\n") } func testPrefixTableName(engine *Engine, t *testing.T) { - tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) - if err != nil { - t.Error(err) - panic(err) - } - tempEngine.ShowSQL = true - mapper := NewPrefixMapper(SnakeMapper{}, "xlw_") - //tempEngine.SetMapper(mapper) - tempEngine.SetTableMapper(mapper) - exist, err := tempEngine.IsTableExist(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if exist { - err = tempEngine.DropTables(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - } - err = tempEngine.CreateTables(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } + tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + if err != nil { + t.Error(err) + panic(err) + } + tempEngine.ShowSQL = true + mapper := NewPrefixMapper(SnakeMapper{}, "xlw_") + //tempEngine.SetMapper(mapper) + tempEngine.SetTableMapper(mapper) + exist, err := tempEngine.IsTableExist(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + if exist { + err = tempEngine.DropTables(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } + } + err = tempEngine.CreateTables(&Userinfo{}) + if err != nil { + t.Error(err) + panic(err) + } } type CreatedUpdated struct { - Id int64 - Name string - Value float64 `xorm:"numeric"` - Created time.Time `xorm:"created"` - Created2 time.Time `xorm:"created"` - Updated time.Time `xorm:"updated"` + Id int64 + Name string + Value float64 `xorm:"numeric"` + Created time.Time `xorm:"created"` + Created2 time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` } func testCreatedUpdated(engine *Engine, t *testing.T) { - err := engine.Sync(&CreatedUpdated{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.Sync(&CreatedUpdated{}) + if err != nil { + t.Error(err) + panic(err) + } - c := &CreatedUpdated{Name: "test"} - _, err = engine.Insert(c) - if err != nil { - t.Error(err) - panic(err) - } + c := &CreatedUpdated{Name: "test"} + _, err = engine.Insert(c) + if err != nil { + t.Error(err) + panic(err) + } - c2 := new(CreatedUpdated) - has, err := engine.Id(c.Id).Get(c2) - if err != nil { - t.Error(err) - panic(err) - } + c2 := new(CreatedUpdated) + has, err := engine.Id(c.Id).Get(c2) + if err != nil { + t.Error(err) + panic(err) + } - if !has { - panic(errors.New("no id")) - } + if !has { + panic(errors.New("no id")) + } - c2.Value -= 1 - _, err = engine.Id(c2.Id).Update(c2) - if err != nil { - t.Error(err) - panic(err) - } + c2.Value -= 1 + _, err = engine.Id(c2.Id).Update(c2) + if err != nil { + t.Error(err) + panic(err) + } } type ProcessorsStruct struct { - Id int64 + Id int64 - B4InsertFlag int - AfterInsertedFlag int - B4UpdateFlag int - AfterUpdatedFlag int - B4DeleteFlag int `xorm:"-"` - AfterDeletedFlag int `xorm:"-"` + B4InsertFlag int + AfterInsertedFlag int + B4UpdateFlag int + AfterUpdatedFlag int + B4DeleteFlag int `xorm:"-"` + AfterDeletedFlag int `xorm:"-"` - B4InsertViaExt int - AfterInsertedViaExt int - B4UpdateViaExt int - AfterUpdatedViaExt int - B4DeleteViaExt int `xorm:"-"` - AfterDeletedViaExt int `xorm:"-"` + B4InsertViaExt int + AfterInsertedViaExt int + B4UpdateViaExt int + AfterUpdatedViaExt int + B4DeleteViaExt int `xorm:"-"` + AfterDeletedViaExt int `xorm:"-"` } func (p *ProcessorsStruct) BeforeInsert() { - p.B4InsertFlag = 1 + p.B4InsertFlag = 1 } func (p *ProcessorsStruct) BeforeUpdate() { - p.B4UpdateFlag = 1 + p.B4UpdateFlag = 1 } func (p *ProcessorsStruct) BeforeDelete() { - p.B4DeleteFlag = 1 + p.B4DeleteFlag = 1 } func (p *ProcessorsStruct) AfterInsert() { - p.AfterInsertedFlag = 1 + p.AfterInsertedFlag = 1 } func (p *ProcessorsStruct) AfterUpdate() { - p.AfterUpdatedFlag = 1 + p.AfterUpdatedFlag = 1 } func (p *ProcessorsStruct) AfterDelete() { - p.AfterDeletedFlag = 1 + p.AfterDeletedFlag = 1 } func testProcessors(engine *Engine, t *testing.T) { - // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) - // if err != nil { - // t.Error(err) - // panic(err) - // } + // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + // if err != nil { + // t.Error(err) + // panic(err) + // } - engine.ShowSQL = true - err := engine.DropTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } - p := &ProcessorsStruct{} + engine.ShowSQL = true + err := engine.DropTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } + p := &ProcessorsStruct{} - err = engine.CreateTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } - b4InsertFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.B4InsertViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + b4InsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4InsertViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - afterInsertFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.AfterInsertedViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + afterInsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterInsertedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - _, err = engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedViaExt == 0 { - t.Error(errors.New("AfterInsertedViaExt not set")) - } - } + _, err = engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } - p2 := &ProcessorsStruct{} - _, err = engine.Id(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p2.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p2.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p2.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - } - // -- + p2 := &ProcessorsStruct{} + _, err = engine.Id(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + // -- - // test update processors - b4UpdateFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.B4UpdateViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + // test update processors + b4UpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4UpdateViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - afterUpdateFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.AfterUpdatedViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + afterUpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterUpdatedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - p = p2 // reset + p = p2 // reset - _, err = engine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } - } + _, err = engine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } - p2 = &ProcessorsStruct{} - _, err = engine.Id(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p2.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) - } - if p2.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p2.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) - } - } - // -- + p2 = &ProcessorsStruct{} + _, err = engine.Id(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set: " + string(p.AfterUpdatedFlag))) + } + if p2.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set: " + string(p.AfterUpdatedViaExt))) + } + } + // -- - // test delete processors - b4DeleteFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.B4DeleteViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + // test delete processors + b4DeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4DeleteViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - afterDeleteFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.AfterDeletedViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + afterDeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterDeletedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - p = p2 // reset - _, err = engine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag == 0 { - t.Error(errors.New("AfterDeletedFlag not set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt == 0 { - t.Error(errors.New("AfterDeletedViaExt not set")) - } - } - // -- + p = p2 // reset + _, err = engine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) + } + } + // -- - // test insert multi - pslice := make([]*ProcessorsStruct, 0) - pslice = append(pslice, &ProcessorsStruct{}) - pslice = append(pslice, &ProcessorsStruct{}) - cnt, err := engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) - if err != nil { - t.Error(err) - panic(err) - } else { - if cnt != 2 { - t.Error(errors.New("incorrect insert count")) - } - for _, elem := range pslice { - if elem.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if elem.AfterInsertedFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if elem.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if elem.AfterInsertedViaExt == 0 { - t.Error(errors.New("AfterInsertedViaExt not set")) - } - } - } + // test insert multi + pslice := make([]*ProcessorsStruct, 0) + pslice = append(pslice, &ProcessorsStruct{}) + pslice = append(pslice, &ProcessorsStruct{}) + cnt, err := engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) + if err != nil { + t.Error(err) + panic(err) + } else { + if cnt != 2 { + t.Error(errors.New("incorrect insert count")) + } + for _, elem := range pslice { + if elem.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.AfterInsertedFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if elem.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } + } - for _, elem := range pslice { - p = &ProcessorsStruct{} - _, err = engine.Id(elem.Id).Get(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p2.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p2.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p2.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - } - } - // -- + for _, elem := range pslice { + p = &ProcessorsStruct{} + _, err = engine.Id(elem.Id).Get(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + } + // -- } func testProcessorsTx(engine *Engine, t *testing.T) { - // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) - // if err != nil { - // t.Error(err) - // panic(err) - // } + // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + // if err != nil { + // t.Error(err) + // panic(err) + // } - // tempEngine.ShowSQL = true - err := engine.DropTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + // tempEngine.ShowSQL = true + err := engine.DropTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&ProcessorsStruct{}) + if err != nil { + t.Error(err) + panic(err) + } - // test insert processors with tx rollback - session := engine.NewSession() - err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + // test insert processors with tx rollback + session := engine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - p := &ProcessorsStruct{} - b4InsertFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.B4InsertViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + p := &ProcessorsStruct{} + b4InsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4InsertViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - afterInsertFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.AfterInsertedViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } - _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("B4InsertFlag is set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - } + afterInsertFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterInsertedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } + _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } - err = session.Rollback() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("B4InsertFlag is set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - } - session.Close() - p2 := &ProcessorsStruct{} - _, err = engine.Id(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.Id > 0 { - err = errors.New("tx got committed upon insert!?") - t.Error(err) - panic(err) - } - } - // -- + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("B4InsertFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + session.Close() + p2 := &ProcessorsStruct{} + _, err = engine.Id(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.Id > 0 { + err = errors.New("tx got committed upon insert!?") + t.Error(err) + panic(err) + } + } + // -- - // test insert processors with tx commit - session = engine.NewSession() - err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + // test insert processors with tx commit + session = engine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - p = &ProcessorsStruct{} - _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - } + p = &ProcessorsStruct{} + _, err = session.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p.AfterInsertedFlag == 0 { - t.Error(errors.New("AfterInsertedFlag not set")) - } - if p.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p.AfterInsertedViaExt == 0 { - t.Error(errors.New("AfterInsertedViaExt not set")) - } - } - session.Close() - p2 = &ProcessorsStruct{} - _, err = engine.Id(p.Id).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4InsertFlag == 0 { - t.Error(errors.New("B4InsertFlag not set")) - } - if p2.AfterInsertedFlag != 0 { - t.Error(errors.New("AfterInsertedFlag is set")) - } - if p2.B4InsertViaExt == 0 { - t.Error(errors.New("B4InsertViaExt not set")) - } - if p2.AfterInsertedViaExt != 0 { - t.Error(errors.New("AfterInsertedViaExt is set")) - } - } - insertedId := p2.Id - // -- + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p.AfterInsertedFlag == 0 { + t.Error(errors.New("AfterInsertedFlag not set")) + } + if p.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p.AfterInsertedViaExt == 0 { + t.Error(errors.New("AfterInsertedViaExt not set")) + } + } + session.Close() + p2 = &ProcessorsStruct{} + _, err = engine.Id(p.Id).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4InsertFlag == 0 { + t.Error(errors.New("B4InsertFlag not set")) + } + if p2.AfterInsertedFlag != 0 { + t.Error(errors.New("AfterInsertedFlag is set")) + } + if p2.B4InsertViaExt == 0 { + t.Error(errors.New("B4InsertViaExt not set")) + } + if p2.AfterInsertedViaExt != 0 { + t.Error(errors.New("AfterInsertedViaExt is set")) + } + } + insertedId := p2.Id + // -- - // test update processors with tx rollback - session = engine.NewSession() - err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + // test update processors with tx rollback + session = engine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - b4UpdateFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.B4UpdateViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + b4UpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4UpdateViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - afterUpdateFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.AfterUpdatedViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + afterUpdateFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterUpdatedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - p = p2 // reset + p = p2 // reset - _, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } - } - err = session.Rollback() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } - } + _, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } - session.Close() - p2 = &ProcessorsStruct{} - _, err = engine.Id(insertedId).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4UpdateFlag != 0 { - t.Error(errors.New("B4UpdateFlag is set")) - } - if p2.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p2.B4UpdateViaExt != 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p2.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } - } - // -- + session.Close() + p2 = &ProcessorsStruct{} + _, err = engine.Id(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4UpdateFlag != 0 { + t.Error(errors.New("B4UpdateFlag is set")) + } + if p2.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p2.B4UpdateViaExt != 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p2.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + // -- - // test update processors with tx commit - session = engine.NewSession() - err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + // test update processors with tx commit + session = engine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - p = &ProcessorsStruct{} + p = &ProcessorsStruct{} - _, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag != 0 { - t.Error(errors.New("AfterUpdatedFlag is set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt != 0 { - t.Error(errors.New("AfterUpdatedViaExt is set")) - } - } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } - } - session.Close() - p2 = &ProcessorsStruct{} - _, err = engine.Id(insertedId).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4UpdateFlag == 0 { - t.Error(errors.New("B4UpdateFlag not set")) - } - if p.AfterUpdatedFlag == 0 { - t.Error(errors.New("AfterUpdatedFlag not set")) - } - if p.B4UpdateViaExt == 0 { - t.Error(errors.New("B4UpdateViaExt not set")) - } - if p.AfterUpdatedViaExt == 0 { - t.Error(errors.New("AfterUpdatedViaExt not set")) - } - } - // -- + _, err = session.Id(insertedId).Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag != 0 { + t.Error(errors.New("AfterUpdatedFlag is set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt != 0 { + t.Error(errors.New("AfterUpdatedViaExt is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + session.Close() + p2 = &ProcessorsStruct{} + _, err = engine.Id(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4UpdateFlag == 0 { + t.Error(errors.New("B4UpdateFlag not set")) + } + if p.AfterUpdatedFlag == 0 { + t.Error(errors.New("AfterUpdatedFlag not set")) + } + if p.B4UpdateViaExt == 0 { + t.Error(errors.New("B4UpdateViaExt not set")) + } + if p.AfterUpdatedViaExt == 0 { + t.Error(errors.New("AfterUpdatedViaExt not set")) + } + } + // -- - // test delete processors with tx rollback - session = engine.NewSession() - err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + // test delete processors with tx rollback + session = engine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - b4DeleteFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.B4DeleteViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + b4DeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.B4DeleteViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - afterDeleteFunc := func(bean interface{}) { - if v, ok := (bean).(*ProcessorsStruct); ok { - v.AfterDeletedViaExt = 1 - } else { - t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) - } - } + afterDeleteFunc := func(bean interface{}) { + if v, ok := (bean).(*ProcessorsStruct); ok { + v.AfterDeletedViaExt = 1 + } else { + t.Error(errors.New("cast to ProcessorsStruct failed, how can this be!?")) + } + } - p = &ProcessorsStruct{} // reset + p = &ProcessorsStruct{} // reset - _, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } - } - err = session.Rollback() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } - } - session.Close() + _, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + err = session.Rollback() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + session.Close() - p2 = &ProcessorsStruct{} - _, err = engine.Id(insertedId).Get(p2) - if err != nil { - t.Error(err) - panic(err) - } else { - if p2.B4DeleteFlag != 0 { - t.Error(errors.New("B4DeleteFlag is set")) - } - if p2.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p2.B4DeleteViaExt != 0 { - t.Error(errors.New("B4DeleteViaExt is set")) - } - if p2.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } - } - // -- + p2 = &ProcessorsStruct{} + _, err = engine.Id(insertedId).Get(p2) + if err != nil { + t.Error(err) + panic(err) + } else { + if p2.B4DeleteFlag != 0 { + t.Error(errors.New("B4DeleteFlag is set")) + } + if p2.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p2.B4DeleteViaExt != 0 { + t.Error(errors.New("B4DeleteViaExt is set")) + } + if p2.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + // -- - // test delete processors with tx commit - session = engine.NewSession() - err = session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + // test delete processors with tx commit + session = engine.NewSession() + err = session.Begin() + if err != nil { + t.Error(err) + panic(err) + } - p = &ProcessorsStruct{} + p = &ProcessorsStruct{} - _, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag != 0 { - t.Error(errors.New("AfterDeletedFlag is set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt != 0 { - t.Error(errors.New("AfterDeletedViaExt is set")) - } - } - err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } else { - if p.B4DeleteFlag == 0 { - t.Error(errors.New("B4DeleteFlag not set")) - } - if p.AfterDeletedFlag == 0 { - t.Error(errors.New("AfterDeletedFlag not set")) - } - if p.B4DeleteViaExt == 0 { - t.Error(errors.New("B4DeleteViaExt not set")) - } - if p.AfterDeletedViaExt == 0 { - t.Error(errors.New("AfterDeletedViaExt not set")) - } - } - session.Close() - // -- + _, err = session.Id(insertedId).Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag != 0 { + t.Error(errors.New("AfterDeletedFlag is set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt != 0 { + t.Error(errors.New("AfterDeletedViaExt is set")) + } + } + err = session.Commit() + if err != nil { + t.Error(err) + panic(err) + } else { + if p.B4DeleteFlag == 0 { + t.Error(errors.New("B4DeleteFlag not set")) + } + if p.AfterDeletedFlag == 0 { + t.Error(errors.New("AfterDeletedFlag not set")) + } + if p.B4DeleteViaExt == 0 { + t.Error(errors.New("B4DeleteViaExt not set")) + } + if p.AfterDeletedViaExt == 0 { + t.Error(errors.New("AfterDeletedViaExt not set")) + } + } + session.Close() + // -- } type NullData struct { - Id int64 - StringPtr *string - StringPtr2 *string `xorm:"text"` - BoolPtr *bool - BytePtr *byte - UintPtr *uint - Uint8Ptr *uint8 - Uint16Ptr *uint16 - Uint32Ptr *uint32 - Uint64Ptr *uint64 - IntPtr *int - Int8Ptr *int8 - Int16Ptr *int16 - Int32Ptr *int32 - Int64Ptr *int64 - RunePtr *rune - Float32Ptr *float32 - Float64Ptr *float64 - // Complex64Ptr *complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' - // Complex128Ptr *complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' - TimePtr *time.Time + Id int64 + StringPtr *string + StringPtr2 *string `xorm:"text"` + BoolPtr *bool + BytePtr *byte + UintPtr *uint + Uint8Ptr *uint8 + Uint16Ptr *uint16 + Uint32Ptr *uint32 + Uint64Ptr *uint64 + IntPtr *int + Int8Ptr *int8 + Int16Ptr *int16 + Int32Ptr *int32 + Int64Ptr *int64 + RunePtr *rune + Float32Ptr *float32 + Float64Ptr *float64 + // Complex64Ptr *complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr *complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + TimePtr *time.Time } type NullData2 struct { - Id int64 - StringPtr string - StringPtr2 string `xorm:"text"` - BoolPtr bool - BytePtr byte - UintPtr uint - Uint8Ptr uint8 - Uint16Ptr uint16 - Uint32Ptr uint32 - Uint64Ptr uint64 - IntPtr int - Int8Ptr int8 - Int16Ptr int16 - Int32Ptr int32 - Int64Ptr int64 - RunePtr rune - Float32Ptr float32 - Float64Ptr float64 - // Complex64Ptr complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' - // Complex128Ptr complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' - TimePtr time.Time + Id int64 + StringPtr string + StringPtr2 string `xorm:"text"` + BoolPtr bool + BytePtr byte + UintPtr uint + Uint8Ptr uint8 + Uint16Ptr uint16 + Uint32Ptr uint32 + Uint64Ptr uint64 + IntPtr int + Int8Ptr int8 + Int16Ptr int16 + Int32Ptr int32 + Int64Ptr int64 + RunePtr rune + Float32Ptr float32 + Float64Ptr float64 + // Complex64Ptr complex64 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + // Complex128Ptr complex128 // !nashtsai! XORM yet support complex128: 'json: unsupported type: complex128' + TimePtr time.Time } type NullData3 struct { - Id int64 - StringPtr *string + Id int64 + StringPtr *string } func testPointerData(engine *Engine, t *testing.T) { - err := engine.DropTables(&NullData{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&NullData{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } - nullData := NullData{ - StringPtr: new(string), - StringPtr2: new(string), - BoolPtr: new(bool), - BytePtr: new(byte), - UintPtr: new(uint), - Uint8Ptr: new(uint8), - Uint16Ptr: new(uint16), - Uint32Ptr: new(uint32), - Uint64Ptr: new(uint64), - IntPtr: new(int), - Int8Ptr: new(int8), - Int16Ptr: new(int16), - Int32Ptr: new(int32), - Int64Ptr: new(int64), - RunePtr: new(rune), - Float32Ptr: new(float32), - Float64Ptr: new(float64), - // Complex64Ptr: new(complex64), - // Complex128Ptr: new(complex128), - TimePtr: new(time.Time), - } + nullData := NullData{ + StringPtr: new(string), + StringPtr2: new(string), + BoolPtr: new(bool), + BytePtr: new(byte), + UintPtr: new(uint), + Uint8Ptr: new(uint8), + Uint16Ptr: new(uint16), + Uint32Ptr: new(uint32), + Uint64Ptr: new(uint64), + IntPtr: new(int), + Int8Ptr: new(int8), + Int16Ptr: new(int16), + Int32Ptr: new(int32), + Int64Ptr: new(int64), + RunePtr: new(rune), + Float32Ptr: new(float32), + Float64Ptr: new(float64), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), + TimePtr: new(time.Time), + } - *nullData.StringPtr = "abc" - *nullData.StringPtr2 = "123" - *nullData.BoolPtr = true - *nullData.BytePtr = 1 - *nullData.UintPtr = 1 - *nullData.Uint8Ptr = 1 - *nullData.Uint16Ptr = 1 - *nullData.Uint32Ptr = 1 - *nullData.Uint64Ptr = 1 - *nullData.IntPtr = -1 - *nullData.Int8Ptr = -1 - *nullData.Int16Ptr = -1 - *nullData.Int32Ptr = -1 - *nullData.Int64Ptr = -1 - *nullData.RunePtr = 1 - *nullData.Float32Ptr = -1.2 - *nullData.Float64Ptr = -1.1 - // *nullData.Complex64Ptr = 123456789012345678901234567890 - // *nullData.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 - *nullData.TimePtr = time.Now() + *nullData.StringPtr = "abc" + *nullData.StringPtr2 = "123" + *nullData.BoolPtr = true + *nullData.BytePtr = 1 + *nullData.UintPtr = 1 + *nullData.Uint8Ptr = 1 + *nullData.Uint16Ptr = 1 + *nullData.Uint32Ptr = 1 + *nullData.Uint64Ptr = 1 + *nullData.IntPtr = -1 + *nullData.Int8Ptr = -1 + *nullData.Int16Ptr = -1 + *nullData.Int32Ptr = -1 + *nullData.Int64Ptr = -1 + *nullData.RunePtr = 1 + *nullData.Float32Ptr = -1.2 + *nullData.Float64Ptr = -1.1 + // *nullData.Complex64Ptr = 123456789012345678901234567890 + // *nullData.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 + *nullData.TimePtr = time.Now() - cnt, err := engine.Insert(&nullData) - fmt.Println(nullData.Id) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } - if nullData.Id <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } - // verify get values - nullDataGet := NullData{} - has, err := engine.Id(nullData.Id).Get(&nullDataGet) - if err != nil { - t.Error(err) - panic(err) - } else if !has { - t.Error(errors.New("ID not found")) - } + // verify get values + nullDataGet := NullData{} + has, err := engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } - if *nullDataGet.StringPtr != *nullData.StringPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) - } + if *nullDataGet.StringPtr != *nullData.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) + } - if *nullDataGet.StringPtr2 != *nullData.StringPtr2 { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) - } + if *nullDataGet.StringPtr2 != *nullData.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) + } - if *nullDataGet.BoolPtr != *nullData.BoolPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) - } + if *nullDataGet.BoolPtr != *nullData.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) + } - if *nullDataGet.UintPtr != *nullData.UintPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) - } + if *nullDataGet.UintPtr != *nullData.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) + } - if *nullDataGet.Uint8Ptr != *nullData.Uint8Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) - } + if *nullDataGet.Uint8Ptr != *nullData.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) + } - if *nullDataGet.Uint16Ptr != *nullData.Uint16Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) - } + if *nullDataGet.Uint16Ptr != *nullData.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) + } - if *nullDataGet.Uint32Ptr != *nullData.Uint32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) - } + if *nullDataGet.Uint32Ptr != *nullData.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) + } - if *nullDataGet.Uint64Ptr != *nullData.Uint64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) - } + if *nullDataGet.Uint64Ptr != *nullData.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) + } - if *nullDataGet.IntPtr != *nullData.IntPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) - } + if *nullDataGet.IntPtr != *nullData.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) + } - if *nullDataGet.Int8Ptr != *nullData.Int8Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) - } + if *nullDataGet.Int8Ptr != *nullData.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) + } - if *nullDataGet.Int16Ptr != *nullData.Int16Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) - } + if *nullDataGet.Int16Ptr != *nullData.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) + } - if *nullDataGet.Int32Ptr != *nullData.Int32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) - } + if *nullDataGet.Int32Ptr != *nullData.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) + } - if *nullDataGet.Int64Ptr != *nullData.Int64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) - } + if *nullDataGet.Int64Ptr != *nullData.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) + } - if *nullDataGet.RunePtr != *nullData.RunePtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) - } + if *nullDataGet.RunePtr != *nullData.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) + } - if *nullDataGet.Float32Ptr != *nullData.Float32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) - } + if *nullDataGet.Float32Ptr != *nullData.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) + } - if *nullDataGet.Float64Ptr != *nullData.Float64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) - } + if *nullDataGet.Float64Ptr != *nullData.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) + } - // if *nullDataGet.Complex64Ptr != *nullData.Complex64Ptr { - // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) - // } + // if *nullDataGet.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } - // if *nullDataGet.Complex128Ptr != *nullData.Complex128Ptr { - // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) - // } + // if *nullDataGet.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } - /*if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) - } else { - // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver - fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr) - fmt.Println() - }*/ - // -- + /*if (*nullDataGet.TimePtr).Unix() != (*nullData.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr))) + } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullData.TimePtr) + fmt.Println() + }*/ + // -- - // using instance type should just work too - nullData2Get := NullData2{} + // using instance type should just work too + nullData2Get := NullData2{} - has, err = engine.Table("null_data").Id(nullData.Id).Get(&nullData2Get) - if err != nil { - t.Error(err) - panic(err) - } else if !has { - t.Error(errors.New("ID not found")) - } + has, err = engine.Table("null_data").Id(nullData.Id).Get(&nullData2Get) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } - if nullData2Get.StringPtr != *nullData.StringPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr))) - } + if nullData2Get.StringPtr != *nullData.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr))) + } - if nullData2Get.StringPtr2 != *nullData.StringPtr2 { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr2))) - } + if nullData2Get.StringPtr2 != *nullData.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.StringPtr2))) + } - if nullData2Get.BoolPtr != *nullData.BoolPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", nullData2Get.BoolPtr))) - } + if nullData2Get.BoolPtr != *nullData.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", nullData2Get.BoolPtr))) + } - if nullData2Get.UintPtr != *nullData.UintPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.UintPtr))) - } + if nullData2Get.UintPtr != *nullData.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.UintPtr))) + } - if nullData2Get.Uint8Ptr != *nullData.Uint8Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint8Ptr))) - } + if nullData2Get.Uint8Ptr != *nullData.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint8Ptr))) + } - if nullData2Get.Uint16Ptr != *nullData.Uint16Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint16Ptr))) - } + if nullData2Get.Uint16Ptr != *nullData.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint16Ptr))) + } - if nullData2Get.Uint32Ptr != *nullData.Uint32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint32Ptr))) - } + if nullData2Get.Uint32Ptr != *nullData.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint32Ptr))) + } - if nullData2Get.Uint64Ptr != *nullData.Uint64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint64Ptr))) - } + if nullData2Get.Uint64Ptr != *nullData.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Uint64Ptr))) + } - if nullData2Get.IntPtr != *nullData.IntPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.IntPtr))) - } + if nullData2Get.IntPtr != *nullData.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.IntPtr))) + } - if nullData2Get.Int8Ptr != *nullData.Int8Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int8Ptr))) - } + if nullData2Get.Int8Ptr != *nullData.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int8Ptr))) + } - if nullData2Get.Int16Ptr != *nullData.Int16Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int16Ptr))) - } + if nullData2Get.Int16Ptr != *nullData.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int16Ptr))) + } - if nullData2Get.Int32Ptr != *nullData.Int32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int32Ptr))) - } + if nullData2Get.Int32Ptr != *nullData.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int32Ptr))) + } - if nullData2Get.Int64Ptr != *nullData.Int64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int64Ptr))) - } + if nullData2Get.Int64Ptr != *nullData.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Int64Ptr))) + } - if nullData2Get.RunePtr != *nullData.RunePtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.RunePtr))) - } + if nullData2Get.RunePtr != *nullData.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.RunePtr))) + } - if nullData2Get.Float32Ptr != *nullData.Float32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float32Ptr))) - } + if nullData2Get.Float32Ptr != *nullData.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float32Ptr))) + } - if nullData2Get.Float64Ptr != *nullData.Float64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr))) - } + if nullData2Get.Float64Ptr != *nullData.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Float64Ptr))) + } - // if nullData2Get.Complex64Ptr != *nullData.Complex64Ptr { - // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex64Ptr))) - // } + // if nullData2Get.Complex64Ptr != *nullData.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex64Ptr))) + // } - // if nullData2Get.Complex128Ptr != *nullData.Complex128Ptr { - // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex128Ptr))) - // } + // if nullData2Get.Complex128Ptr != *nullData.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", nullData2Get.Complex128Ptr))) + // } - /*if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr))) - } else { - // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver - fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr) - fmt.Println() - }*/ - // -- + /*if nullData2Get.TimePtr.Unix() != (*nullData.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr))) + } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + fmt.Printf("time value: [%v]:[%v]", nullData2Get.TimePtr, *nullData.TimePtr) + fmt.Println() + }*/ + // -- } func testNullValue(engine *Engine, t *testing.T) { - err := engine.DropTables(&NullData{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = engine.CreateTables(&NullData{}) - if err != nil { - t.Error(err) - panic(err) - } - - nullData := NullData{} - - cnt, err := engine.Insert(&nullData) - fmt.Println(nullData.Id) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } - if nullData.Id <= 0 { - err = errors.New("not return id error") - t.Error(err) - panic(err) - } - - nullDataGet := NullData{} - - has, err := engine.Id(nullData.Id).Get(&nullDataGet) - if err != nil { - t.Error(err) - panic(err) - } else if !has { - t.Error(errors.New("ID not found")) - } - - if nullDataGet.StringPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) - } - - if nullDataGet.StringPtr2 != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) - } - - if nullDataGet.BoolPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) - } - - if nullDataGet.UintPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) - } - - if nullDataGet.Uint8Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) - } - - if nullDataGet.Uint16Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) - } - - if nullDataGet.Uint32Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) - } - - if nullDataGet.Uint64Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) - } - - if nullDataGet.IntPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) - } - - if nullDataGet.Int8Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) - } - - if nullDataGet.Int16Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) - } - - if nullDataGet.Int32Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) - } - - if nullDataGet.Int64Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) - } - - if nullDataGet.RunePtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) - } - - if nullDataGet.Float32Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) - } - - if nullDataGet.Float64Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) - } - - // if nullDataGet.Complex64Ptr != nil { - // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex64Ptr))) - // } - - // if nullDataGet.Complex128Ptr != nil { - // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex128Ptr))) - // } - - if nullDataGet.TimePtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) - } - - nullDataUpdate := NullData{ - StringPtr: new(string), - StringPtr2: new(string), - BoolPtr: new(bool), - BytePtr: new(byte), - UintPtr: new(uint), - Uint8Ptr: new(uint8), - Uint16Ptr: new(uint16), - Uint32Ptr: new(uint32), - Uint64Ptr: new(uint64), - IntPtr: new(int), - Int8Ptr: new(int8), - Int16Ptr: new(int16), - Int32Ptr: new(int32), - Int64Ptr: new(int64), - RunePtr: new(rune), - Float32Ptr: new(float32), - Float64Ptr: new(float64), - // Complex64Ptr: new(complex64), - // Complex128Ptr: new(complex128), - TimePtr: new(time.Time), - } - - *nullDataUpdate.StringPtr = "abc" - *nullDataUpdate.StringPtr2 = "123" - *nullDataUpdate.BoolPtr = true - *nullDataUpdate.BytePtr = 1 - *nullDataUpdate.UintPtr = 1 - *nullDataUpdate.Uint8Ptr = 1 - *nullDataUpdate.Uint16Ptr = 1 - *nullDataUpdate.Uint32Ptr = 1 - *nullDataUpdate.Uint64Ptr = 1 - *nullDataUpdate.IntPtr = -1 - *nullDataUpdate.Int8Ptr = -1 - *nullDataUpdate.Int16Ptr = -1 - *nullDataUpdate.Int32Ptr = -1 - *nullDataUpdate.Int64Ptr = -1 - *nullDataUpdate.RunePtr = 1 - *nullDataUpdate.Float32Ptr = -1.2 - *nullDataUpdate.Float64Ptr = -1.1 - // *nullDataUpdate.Complex64Ptr = 123456789012345678901234567890 - // *nullDataUpdate.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 - *nullDataUpdate.TimePtr = time.Now() - - cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) - if err != nil { - t.Error(err) - panic(err) - } else if cnt != 1 { - t.Error(errors.New("update count == 0, how can this happen!?")) - return - } - - // verify get values - nullDataGet = NullData{} - has, err = engine.Id(nullData.Id).Get(&nullDataGet) - if err != nil { - t.Error(err) - return - } else if !has { - t.Error(errors.New("ID not found")) - return - } - - if *nullDataGet.StringPtr != *nullDataUpdate.StringPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) - } - - if *nullDataGet.StringPtr2 != *nullDataUpdate.StringPtr2 { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) - } - - if *nullDataGet.BoolPtr != *nullDataUpdate.BoolPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) - } - - if *nullDataGet.UintPtr != *nullDataUpdate.UintPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) - } - - if *nullDataGet.Uint8Ptr != *nullDataUpdate.Uint8Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) - } - - if *nullDataGet.Uint16Ptr != *nullDataUpdate.Uint16Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) - } - - if *nullDataGet.Uint32Ptr != *nullDataUpdate.Uint32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) - } - - if *nullDataGet.Uint64Ptr != *nullDataUpdate.Uint64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) - } - - if *nullDataGet.IntPtr != *nullDataUpdate.IntPtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) - } - - if *nullDataGet.Int8Ptr != *nullDataUpdate.Int8Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) - } - - if *nullDataGet.Int16Ptr != *nullDataUpdate.Int16Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) - } - - if *nullDataGet.Int32Ptr != *nullDataUpdate.Int32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) - } - - if *nullDataGet.Int64Ptr != *nullDataUpdate.Int64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) - } - - if *nullDataGet.RunePtr != *nullDataUpdate.RunePtr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) - } - - if *nullDataGet.Float32Ptr != *nullDataUpdate.Float32Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) - } - - if *nullDataGet.Float64Ptr != *nullDataUpdate.Float64Ptr { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) - } - - // if *nullDataGet.Complex64Ptr != *nullDataUpdate.Complex64Ptr { - // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) - // } - - // if *nullDataGet.Complex128Ptr != *nullDataUpdate.Complex128Ptr { - // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) - // } - - /*if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { - t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) - } else { - // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver - fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) - fmt.Println() - }*/ - // -- - - // update to null values - /*nullDataUpdate = NullData{} - - cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) - if err != nil { - t.Error(err) - panic(err) - } else if cnt != 1 { - t.Error(errors.New("update count == 0, how can this happen!?")) - return - }*/ - - // verify get values - /*nullDataGet = NullData{} - has, err = engine.Id(nullData.Id).Get(&nullDataGet) - if err != nil { - t.Error(err) - return - } else if !has { - t.Error(errors.New("ID not found")) - return - } - - fmt.Printf("%+v", nullDataGet) - fmt.Println() - - if nullDataGet.StringPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) - } - - if nullDataGet.StringPtr2 != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) - } - - if nullDataGet.BoolPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) - } - - if nullDataGet.UintPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) - } - - if nullDataGet.Uint8Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) - } - - if nullDataGet.Uint16Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) - } - - if nullDataGet.Uint32Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) - } - - if nullDataGet.Uint64Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) - } - - if nullDataGet.IntPtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) - } - - if nullDataGet.Int8Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) - } - - if nullDataGet.Int16Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) - } - - if nullDataGet.Int32Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) - } - - if nullDataGet.Int64Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) - } - - if nullDataGet.RunePtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) - } - - if nullDataGet.Float32Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) - } - - if nullDataGet.Float64Ptr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) - } - - // if nullDataGet.Complex64Ptr != nil { - // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) - // } - - // if nullDataGet.Complex128Ptr != nil { - // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) - // } - - if nullDataGet.TimePtr != nil { - t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) - }*/ - // -- + err := engine.DropTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&NullData{}) + if err != nil { + t.Error(err) + panic(err) + } + + nullData := NullData{} + + cnt, err := engine.Insert(&nullData) + fmt.Println(nullData.Id) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert not returned 1") + t.Error(err) + panic(err) + return + } + if nullData.Id <= 0 { + err = errors.New("not return id error") + t.Error(err) + panic(err) + } + + nullDataGet := NullData{} + + has, err := engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + panic(err) + } else if !has { + t.Error(errors.New("ID not found")) + } + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Complex128Ptr))) + // } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + } + + nullDataUpdate := NullData{ + StringPtr: new(string), + StringPtr2: new(string), + BoolPtr: new(bool), + BytePtr: new(byte), + UintPtr: new(uint), + Uint8Ptr: new(uint8), + Uint16Ptr: new(uint16), + Uint32Ptr: new(uint32), + Uint64Ptr: new(uint64), + IntPtr: new(int), + Int8Ptr: new(int8), + Int16Ptr: new(int16), + Int32Ptr: new(int32), + Int64Ptr: new(int64), + RunePtr: new(rune), + Float32Ptr: new(float32), + Float64Ptr: new(float64), + // Complex64Ptr: new(complex64), + // Complex128Ptr: new(complex128), + TimePtr: new(time.Time), + } + + *nullDataUpdate.StringPtr = "abc" + *nullDataUpdate.StringPtr2 = "123" + *nullDataUpdate.BoolPtr = true + *nullDataUpdate.BytePtr = 1 + *nullDataUpdate.UintPtr = 1 + *nullDataUpdate.Uint8Ptr = 1 + *nullDataUpdate.Uint16Ptr = 1 + *nullDataUpdate.Uint32Ptr = 1 + *nullDataUpdate.Uint64Ptr = 1 + *nullDataUpdate.IntPtr = -1 + *nullDataUpdate.Int8Ptr = -1 + *nullDataUpdate.Int16Ptr = -1 + *nullDataUpdate.Int32Ptr = -1 + *nullDataUpdate.Int64Ptr = -1 + *nullDataUpdate.RunePtr = 1 + *nullDataUpdate.Float32Ptr = -1.2 + *nullDataUpdate.Float64Ptr = -1.1 + // *nullDataUpdate.Complex64Ptr = 123456789012345678901234567890 + // *nullDataUpdate.Complex128Ptr = 123456789012345678901234567890123456789012345678901234567890 + *nullDataUpdate.TimePtr = time.Now() + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + } + + // verify get values + nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } + + if *nullDataGet.StringPtr != *nullDataUpdate.StringPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr))) + } + + if *nullDataGet.StringPtr2 != *nullDataUpdate.StringPtr2 { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.StringPtr2))) + } + + if *nullDataGet.BoolPtr != *nullDataUpdate.BoolPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%t]", *nullDataGet.BoolPtr))) + } + + if *nullDataGet.UintPtr != *nullDataUpdate.UintPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.UintPtr))) + } + + if *nullDataGet.Uint8Ptr != *nullDataUpdate.Uint8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint8Ptr))) + } + + if *nullDataGet.Uint16Ptr != *nullDataUpdate.Uint16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint16Ptr))) + } + + if *nullDataGet.Uint32Ptr != *nullDataUpdate.Uint32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint32Ptr))) + } + + if *nullDataGet.Uint64Ptr != *nullDataUpdate.Uint64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Uint64Ptr))) + } + + if *nullDataGet.IntPtr != *nullDataUpdate.IntPtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.IntPtr))) + } + + if *nullDataGet.Int8Ptr != *nullDataUpdate.Int8Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int8Ptr))) + } + + if *nullDataGet.Int16Ptr != *nullDataUpdate.Int16Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int16Ptr))) + } + + if *nullDataGet.Int32Ptr != *nullDataUpdate.Int32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int32Ptr))) + } + + if *nullDataGet.Int64Ptr != *nullDataUpdate.Int64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Int64Ptr))) + } + + if *nullDataGet.RunePtr != *nullDataUpdate.RunePtr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.RunePtr))) + } + + if *nullDataGet.Float32Ptr != *nullDataUpdate.Float32Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float32Ptr))) + } + + if *nullDataGet.Float64Ptr != *nullDataUpdate.Float64Ptr { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Float64Ptr))) + } + + // if *nullDataGet.Complex64Ptr != *nullDataUpdate.Complex64Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex64Ptr))) + // } + + // if *nullDataGet.Complex128Ptr != *nullDataUpdate.Complex128Ptr { + // t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]", *nullDataGet.Complex128Ptr))) + // } + + /*if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { + t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) + } else { + // !nashtsai! mymysql driver will failed this test case, due the time is roundup to nearest second, I would considered this is a bug in mymysql driver + fmt.Printf("time value: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr) + fmt.Println() + }*/ + // -- + + // update to null values + /*nullDataUpdate = NullData{} + + cnt, err = engine.Id(nullData.Id).Update(&nullDataUpdate) + if err != nil { + t.Error(err) + panic(err) + } else if cnt != 1 { + t.Error(errors.New("update count == 0, how can this happen!?")) + return + }*/ + + // verify get values + /*nullDataGet = NullData{} + has, err = engine.Id(nullData.Id).Get(&nullDataGet) + if err != nil { + t.Error(err) + return + } else if !has { + t.Error(errors.New("ID not found")) + return + } + + fmt.Printf("%+v", nullDataGet) + fmt.Println() + + if nullDataGet.StringPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr))) + } + + if nullDataGet.StringPtr2 != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.StringPtr2))) + } + + if nullDataGet.BoolPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%t]", *nullDataGet.BoolPtr))) + } + + if nullDataGet.UintPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.UintPtr))) + } + + if nullDataGet.Uint8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint8Ptr))) + } + + if nullDataGet.Uint16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint16Ptr))) + } + + if nullDataGet.Uint32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint32Ptr))) + } + + if nullDataGet.Uint64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Uint64Ptr))) + } + + if nullDataGet.IntPtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.IntPtr))) + } + + if nullDataGet.Int8Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int8Ptr))) + } + + if nullDataGet.Int16Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int16Ptr))) + } + + if nullDataGet.Int32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int32Ptr))) + } + + if nullDataGet.Int64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Int64Ptr))) + } + + if nullDataGet.RunePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.RunePtr))) + } + + if nullDataGet.Float32Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float32Ptr))) + } + + if nullDataGet.Float64Ptr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + } + + // if nullDataGet.Complex64Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + // if nullDataGet.Complex128Ptr != nil { + // t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.Float64Ptr))) + // } + + if nullDataGet.TimePtr != nil { + t.Error(errors.New(fmt.Sprintf("not null value: [%v]", *nullDataGet.TimePtr))) + }*/ + // -- } type CompositeKey struct { - Id1 int64 `xorm:"id1 pk"` - Id2 int64 `xorm:"id2 pk"` - UpdateStr string + Id1 int64 `xorm:"id1 pk"` + Id2 int64 `xorm:"id2 pk"` + UpdateStr string } func testCompositeKey(engine *Engine, t *testing.T) { - err := engine.DropTables(&CompositeKey{}) - if err != nil { - t.Error(err) - panic(err) - } + err := engine.DropTables(&CompositeKey{}) + if err != nil { + t.Error(err) + panic(err) + } - err = engine.CreateTables(&CompositeKey{}) - if err != nil { - t.Error(err) - panic(err) - } + err = engine.CreateTables(&CompositeKey{}) + if err != nil { + t.Error(err) + panic(err) + } - cnt, err := engine.Insert(&CompositeKey{11, 22, ""}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("failed to insert CompositeKey{11, 22}")) - } + cnt, err := engine.Insert(&CompositeKey{11, 22, ""}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("failed to insert CompositeKey{11, 22}")) + } - cnt, err = engine.Insert(&CompositeKey{11, 22, ""}) - if err == nil || cnt == 1 { - t.Error(errors.New("inserted CompositeKey{11, 22}")) - } + cnt, err = engine.Insert(&CompositeKey{11, 22, ""}) + if err == nil || cnt == 1 { + t.Error(errors.New("inserted CompositeKey{11, 22}")) + } - var compositeKeyVal CompositeKey - has, err := engine.Id(PK{11, 22}).Get(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get CompositeKey{11, 22}")) - } + var compositeKeyVal CompositeKey + has, err := engine.Id(PK{11, 22}).Get(&compositeKeyVal) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get CompositeKey{11, 22}")) + } - // test passing PK ptr, this test seem failed withCache - has, err = engine.Id(&PK{11, 22}).Get(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if !has { - t.Error(errors.New("can't get CompositeKey{11, 22}")) - } + // test passing PK ptr, this test seem failed withCache + has, err = engine.Id(&PK{11, 22}).Get(&compositeKeyVal) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get CompositeKey{11, 22}")) + } - compositeKeyVal = CompositeKey{UpdateStr:"test1"} - cnt, err = engine.Id(PK{11, 22}).Update(&compositeKeyVal) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't update CompositeKey{11, 22}")) - } + compositeKeyVal = CompositeKey{UpdateStr: "test1"} + cnt, err = engine.Id(PK{11, 22}).Update(&compositeKeyVal) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't update CompositeKey{11, 22}")) + } - cnt, err = engine.Id(PK{11, 22}).Delete(&CompositeKey{}) - if err != nil { - t.Error(err) - } else if cnt != 1 { - t.Error(errors.New("can't delete CompositeKey{11, 22}")) - } + cnt, err = engine.Id(PK{11, 22}).Delete(&CompositeKey{}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't delete CompositeKey{11, 22}")) + } } - func testAll(engine *Engine, t *testing.T) { - fmt.Println("-------------- directCreateTable --------------") - directCreateTable(engine, t) - fmt.Println("-------------- insert --------------") - insert(engine, t) - fmt.Println("-------------- query --------------") - testQuery(engine, t) - fmt.Println("-------------- exec --------------") - exec(engine, t) - fmt.Println("-------------- insertAutoIncr --------------") - insertAutoIncr(engine, t) - fmt.Println("-------------- insertMulti --------------") - insertMulti(engine, t) - fmt.Println("-------------- insertTwoTable --------------") - insertTwoTable(engine, t) - fmt.Println("-------------- update --------------") - update(engine, t) - fmt.Println("-------------- testdelete --------------") - testdelete(engine, t) - fmt.Println("-------------- get --------------") - get(engine, t) - fmt.Println("-------------- cascadeGet --------------") - cascadeGet(engine, t) - fmt.Println("-------------- find --------------") - find(engine, t) - fmt.Println("-------------- find2 --------------") - find2(engine, t) - fmt.Println("-------------- findMap --------------") - findMap(engine, t) - fmt.Println("-------------- findMap2 --------------") - findMap2(engine, t) - fmt.Println("-------------- count --------------") - count(engine, t) - fmt.Println("-------------- where --------------") - where(engine, t) - fmt.Println("-------------- in --------------") - in(engine, t) - fmt.Println("-------------- limit --------------") - limit(engine, t) - fmt.Println("-------------- order --------------") - order(engine, t) - fmt.Println("-------------- join --------------") - join(engine, t) - fmt.Println("-------------- having --------------") - having(engine, t) + fmt.Println("-------------- directCreateTable --------------") + directCreateTable(engine, t) + fmt.Println("-------------- insert --------------") + insert(engine, t) + fmt.Println("-------------- query --------------") + testQuery(engine, t) + fmt.Println("-------------- exec --------------") + exec(engine, t) + fmt.Println("-------------- insertAutoIncr --------------") + insertAutoIncr(engine, t) + fmt.Println("-------------- insertMulti --------------") + insertMulti(engine, t) + fmt.Println("-------------- insertTwoTable --------------") + insertTwoTable(engine, t) + fmt.Println("-------------- update --------------") + update(engine, t) + fmt.Println("-------------- testdelete --------------") + testdelete(engine, t) + fmt.Println("-------------- get --------------") + get(engine, t) + fmt.Println("-------------- cascadeGet --------------") + cascadeGet(engine, t) + fmt.Println("-------------- find --------------") + find(engine, t) + fmt.Println("-------------- find2 --------------") + find2(engine, t) + fmt.Println("-------------- findMap --------------") + findMap(engine, t) + fmt.Println("-------------- findMap2 --------------") + findMap2(engine, t) + fmt.Println("-------------- count --------------") + count(engine, t) + fmt.Println("-------------- where --------------") + where(engine, t) + fmt.Println("-------------- in --------------") + in(engine, t) + fmt.Println("-------------- limit --------------") + limit(engine, t) + fmt.Println("-------------- order --------------") + order(engine, t) + fmt.Println("-------------- join --------------") + join(engine, t) + fmt.Println("-------------- having --------------") + having(engine, t) } func testAll2(engine *Engine, t *testing.T) { - fmt.Println("-------------- combineTransaction --------------") - combineTransaction(engine, t) - fmt.Println("-------------- table --------------") - table(engine, t) - fmt.Println("-------------- createMultiTables --------------") - createMultiTables(engine, t) - fmt.Println("-------------- tableOp --------------") - tableOp(engine, t) - fmt.Println("-------------- testCols --------------") - testCols(engine, t) - fmt.Println("-------------- testCharst --------------") - testCharst(engine, t) - fmt.Println("-------------- testStoreEngine --------------") - testStoreEngine(engine, t) - fmt.Println("-------------- testExtends --------------") - testExtends(engine, t) - fmt.Println("-------------- testColTypes --------------") - testColTypes(engine, t) - fmt.Println("-------------- testCustomType --------------") - testCustomType(engine, t) - fmt.Println("-------------- testCreatedAndUpdated --------------") - testCreatedAndUpdated(engine, t) - fmt.Println("-------------- testIndexAndUnique --------------") - testIndexAndUnique(engine, t) - fmt.Println("-------------- testIntId --------------") - //testIntId(engine, t) - fmt.Println("-------------- testInt32Id --------------") - //testInt32Id(engine, t) - fmt.Println("-------------- testMetaInfo --------------") - testMetaInfo(engine, t) - fmt.Println("-------------- testIterate --------------") - testIterate(engine, t) - fmt.Println("-------------- testStrangeName --------------") - testStrangeName(engine, t) - fmt.Println("-------------- testVersion --------------") - testVersion(engine, t) - fmt.Println("-------------- testDistinct --------------") - testDistinct(engine, t) - fmt.Println("-------------- testUseBool --------------") - testUseBool(engine, t) - fmt.Println("-------------- testBool --------------") - testBool(engine, t) - fmt.Println("-------------- testTime --------------") - testTime(engine, t) - fmt.Println("-------------- testPrefixTableName --------------") - testPrefixTableName(engine, t) - fmt.Println("-------------- testCreatedUpdated --------------") - testCreatedUpdated(engine, t) - fmt.Println("-------------- processors --------------") - testProcessors(engine, t) - fmt.Println("-------------- transaction --------------") - transaction(engine, t) + fmt.Println("-------------- combineTransaction --------------") + combineTransaction(engine, t) + fmt.Println("-------------- table --------------") + table(engine, t) + fmt.Println("-------------- createMultiTables --------------") + createMultiTables(engine, t) + fmt.Println("-------------- tableOp --------------") + tableOp(engine, t) + fmt.Println("-------------- testCols --------------") + testCols(engine, t) + fmt.Println("-------------- testCharst --------------") + testCharst(engine, t) + fmt.Println("-------------- testStoreEngine --------------") + testStoreEngine(engine, t) + fmt.Println("-------------- testExtends --------------") + testExtends(engine, t) + fmt.Println("-------------- testColTypes --------------") + testColTypes(engine, t) + fmt.Println("-------------- testCustomType --------------") + testCustomType(engine, t) + fmt.Println("-------------- testCreatedAndUpdated --------------") + testCreatedAndUpdated(engine, t) + fmt.Println("-------------- testIndexAndUnique --------------") + testIndexAndUnique(engine, t) + fmt.Println("-------------- testIntId --------------") + //testIntId(engine, t) + fmt.Println("-------------- testInt32Id --------------") + //testInt32Id(engine, t) + fmt.Println("-------------- testMetaInfo --------------") + testMetaInfo(engine, t) + fmt.Println("-------------- testIterate --------------") + testIterate(engine, t) + fmt.Println("-------------- testStrangeName --------------") + testStrangeName(engine, t) + fmt.Println("-------------- testVersion --------------") + testVersion(engine, t) + fmt.Println("-------------- testDistinct --------------") + testDistinct(engine, t) + fmt.Println("-------------- testUseBool --------------") + testUseBool(engine, t) + fmt.Println("-------------- testBool --------------") + testBool(engine, t) + fmt.Println("-------------- testTime --------------") + testTime(engine, t) + fmt.Println("-------------- testPrefixTableName --------------") + testPrefixTableName(engine, t) + fmt.Println("-------------- testCreatedUpdated --------------") + testCreatedUpdated(engine, t) + fmt.Println("-------------- processors --------------") + testProcessors(engine, t) + fmt.Println("-------------- transaction --------------") + transaction(engine, t) } // !nash! the 3rd set of the test is intended for non-cache enabled engine func testAll3(engine *Engine, t *testing.T) { - fmt.Println("-------------- processors TX --------------") - testProcessorsTx(engine, t) - fmt.Println("-------------- insert pointer data --------------") - testPointerData(engine, t) - fmt.Println("-------------- insert null data --------------") - testNullValue(engine, t) - fmt.Println("-------------- testCompositeKey --------------") - testCompositeKey(engine, t) + fmt.Println("-------------- processors TX --------------") + testProcessorsTx(engine, t) + fmt.Println("-------------- insert pointer data --------------") + testPointerData(engine, t) + fmt.Println("-------------- insert null data --------------") + testNullValue(engine, t) + fmt.Println("-------------- testCompositeKey --------------") + testCompositeKey(engine, t) } diff --git a/cache.go b/cache.go index 10ec444c..e1ccc0d1 100644 --- a/cache.go +++ b/cache.go @@ -1,131 +1,131 @@ package xorm import ( - "container/list" - "errors" - "fmt" - "strconv" - "strings" - "sync" - "time" + "container/list" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" ) const ( - // default cache expired time - CacheExpired = 60 * time.Minute - // not use now - CacheMaxMemory = 256 - // evey ten minutes to clear all expired nodes - CacheGcInterval = 10 * time.Minute - // each time when gc to removed max nodes - CacheGcMaxRemoved = 20 + // default cache expired time + CacheExpired = 60 * time.Minute + // not use now + CacheMaxMemory = 256 + // evey ten minutes to clear all expired nodes + CacheGcInterval = 10 * time.Minute + // each time when gc to removed max nodes + CacheGcMaxRemoved = 20 ) // CacheStore is a interface to store cache type CacheStore interface { - Put(key, value interface{}) error - Get(key interface{}) (interface{}, error) - Del(key interface{}) error + Put(key, value interface{}) error + Get(key interface{}) (interface{}, error) + Del(key interface{}) error } // MemoryStore implements CacheStore provide local machine // memory store type MemoryStore struct { - store map[interface{}]interface{} - mutex sync.RWMutex + store map[interface{}]interface{} + mutex sync.RWMutex } func NewMemoryStore() *MemoryStore { - return &MemoryStore{store: make(map[interface{}]interface{})} + return &MemoryStore{store: make(map[interface{}]interface{})} } func (s *MemoryStore) Put(key, value interface{}) error { - s.mutex.Lock() - defer s.mutex.Unlock() - s.store[key] = value - return nil + s.mutex.Lock() + defer s.mutex.Unlock() + s.store[key] = value + return nil } func (s *MemoryStore) Get(key interface{}) (interface{}, error) { - s.mutex.RLock() - defer s.mutex.RUnlock() - if v, ok := s.store[key]; ok { - return v, nil - } + s.mutex.RLock() + defer s.mutex.RUnlock() + if v, ok := s.store[key]; ok { + return v, nil + } - return nil, ErrNotExist + return nil, ErrNotExist } func (s *MemoryStore) Del(key interface{}) error { - s.mutex.Lock() - defer s.mutex.Unlock() - delete(s.store, key) - return nil + s.mutex.Lock() + defer s.mutex.Unlock() + delete(s.store, key) + return nil } // Cacher is an interface to provide cache type Cacher interface { - GetIds(tableName, sql string) interface{} - GetBean(tableName string, id int64) interface{} - PutIds(tableName, sql string, ids interface{}) - PutBean(tableName string, id int64, obj interface{}) - DelIds(tableName, sql string) - DelBean(tableName string, id int64) - ClearIds(tableName string) - ClearBeans(tableName string) + GetIds(tableName, sql string) interface{} + GetBean(tableName string, id int64) interface{} + PutIds(tableName, sql string, ids interface{}) + PutBean(tableName string, id int64, obj interface{}) + DelIds(tableName, sql string) + DelBean(tableName string, id int64) + ClearIds(tableName string) + ClearBeans(tableName string) } type idNode struct { - tbName string - id int64 - lastVisit time.Time + tbName string + id int64 + lastVisit time.Time } type sqlNode struct { - tbName string - sql string - lastVisit time.Time + tbName string + sql string + lastVisit time.Time } func newIdNode(tbName string, id int64) *idNode { - return &idNode{tbName, id, time.Now()} + return &idNode{tbName, id, time.Now()} } func newSqlNode(tbName, sql string) *sqlNode { - return &sqlNode{tbName, sql, time.Now()} + return &sqlNode{tbName, sql, time.Now()} } // LRUCacher implements Cacher according to LRU algorithm type LRUCacher struct { - idList *list.List - sqlList *list.List - idIndex map[string]map[interface{}]*list.Element - sqlIndex map[string]map[interface{}]*list.Element - store CacheStore - Max int - mutex sync.Mutex - Expired time.Duration - maxSize int - GcInterval time.Duration + idList *list.List + sqlList *list.List + idIndex map[string]map[interface{}]*list.Element + sqlIndex map[string]map[interface{}]*list.Element + store CacheStore + Max int + mutex sync.Mutex + Expired time.Duration + maxSize int + GcInterval time.Duration } func newLRUCacher(store CacheStore, expired time.Duration, maxSize int, max int) *LRUCacher { - cacher := &LRUCacher{store: store, idList: list.New(), - sqlList: list.New(), Expired: expired, maxSize: maxSize, - GcInterval: CacheGcInterval, Max: max, - sqlIndex: make(map[string]map[interface{}]*list.Element), - idIndex: make(map[string]map[interface{}]*list.Element), - } - cacher.RunGC() - return cacher + cacher := &LRUCacher{store: store, idList: list.New(), + sqlList: list.New(), Expired: expired, maxSize: maxSize, + GcInterval: CacheGcInterval, Max: max, + sqlIndex: make(map[string]map[interface{}]*list.Element), + idIndex: make(map[string]map[interface{}]*list.Element), + } + cacher.RunGC() + return cacher } func NewLRUCacher(store CacheStore, max int) *LRUCacher { - return newLRUCacher(store, CacheExpired, CacheMaxMemory, max) + return newLRUCacher(store, CacheExpired, CacheMaxMemory, max) } func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher { - return newLRUCacher(store, expired, 0, max) + return newLRUCacher(store, expired, 0, max) } //func NewLRUCacher3(store CacheStore, expired time.Duration, maxSize int) *LRUCacher { @@ -134,262 +134,262 @@ func NewLRUCacher2(store CacheStore, expired time.Duration, max int) *LRUCacher // RunGC run once every m.GcInterval func (m *LRUCacher) RunGC() { - time.AfterFunc(m.GcInterval, func() { - m.RunGC() - m.GC() - }) + time.AfterFunc(m.GcInterval, func() { + m.RunGC() + m.GC() + }) } // GC check ids lit and sql list to remove all element expired func (m *LRUCacher) GC() { - //fmt.Println("begin gc ...") - //defer fmt.Println("end gc ...") - m.mutex.Lock() - defer m.mutex.Unlock() - var removedNum int - for e := m.idList.Front(); e != nil; { - if removedNum <= CacheGcMaxRemoved && - time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { - removedNum++ - next := e.Next() - //fmt.Println("removing ...", e.Value) - node := e.Value.(*idNode) - m.delBean(node.tbName, node.id) - e = next - } else { - //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len()) - break - } - } + //fmt.Println("begin gc ...") + //defer fmt.Println("end gc ...") + m.mutex.Lock() + defer m.mutex.Unlock() + var removedNum int + for e := m.idList.Front(); e != nil; { + if removedNum <= CacheGcMaxRemoved && + time.Now().Sub(e.Value.(*idNode).lastVisit) > m.Expired { + removedNum++ + next := e.Next() + //fmt.Println("removing ...", e.Value) + node := e.Value.(*idNode) + m.delBean(node.tbName, node.id) + e = next + } else { + //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.idList.Len()) + break + } + } - removedNum = 0 - for e := m.sqlList.Front(); e != nil; { - if removedNum <= CacheGcMaxRemoved && - time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { - removedNum++ - next := e.Next() - //fmt.Println("removing ...", e.Value) - node := e.Value.(*sqlNode) - m.delIds(node.tbName, node.sql) - e = next - } else { - //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len()) - break - } - } + removedNum = 0 + for e := m.sqlList.Front(); e != nil; { + if removedNum <= CacheGcMaxRemoved && + time.Now().Sub(e.Value.(*sqlNode).lastVisit) > m.Expired { + removedNum++ + next := e.Next() + //fmt.Println("removing ...", e.Value) + node := e.Value.(*sqlNode) + m.delIds(node.tbName, node.sql) + e = next + } else { + //fmt.Printf("removing %d cache nodes ..., left %d\n", removedNum, m.sqlList.Len()) + break + } + } } // Get all bean's ids according to sql and parameter from cache func (m *LRUCacher) GetIds(tableName, sql string) interface{} { - m.mutex.Lock() - defer m.mutex.Unlock() - if _, ok := m.sqlIndex[tableName]; !ok { - m.sqlIndex[tableName] = make(map[interface{}]*list.Element) - } - if v, err := m.store.Get(sql); err == nil { - if el, ok := m.sqlIndex[tableName][sql]; !ok { - el = m.sqlList.PushBack(newSqlNode(tableName, sql)) - m.sqlIndex[tableName][sql] = el - } else { - lastTime := el.Value.(*sqlNode).lastVisit - // if expired, remove the node and return nil - if time.Now().Sub(lastTime) > m.Expired { - m.delIds(tableName, sql) - return nil - } - m.sqlList.MoveToBack(el) - el.Value.(*sqlNode).lastVisit = time.Now() - } - return v - } else { - m.delIds(tableName, sql) - } + m.mutex.Lock() + defer m.mutex.Unlock() + if _, ok := m.sqlIndex[tableName]; !ok { + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + } + if v, err := m.store.Get(sql); err == nil { + if el, ok := m.sqlIndex[tableName][sql]; !ok { + el = m.sqlList.PushBack(newSqlNode(tableName, sql)) + m.sqlIndex[tableName][sql] = el + } else { + lastTime := el.Value.(*sqlNode).lastVisit + // if expired, remove the node and return nil + if time.Now().Sub(lastTime) > m.Expired { + m.delIds(tableName, sql) + return nil + } + m.sqlList.MoveToBack(el) + el.Value.(*sqlNode).lastVisit = time.Now() + } + return v + } else { + m.delIds(tableName, sql) + } - return nil + return nil } // Get bean according tableName and id from cache func (m *LRUCacher) GetBean(tableName string, id int64) interface{} { - m.mutex.Lock() - defer m.mutex.Unlock() - if _, ok := m.idIndex[tableName]; !ok { - m.idIndex[tableName] = make(map[interface{}]*list.Element) - } - tid := genId(tableName, id) - if v, err := m.store.Get(tid); err == nil { - if el, ok := m.idIndex[tableName][id]; ok { - lastTime := el.Value.(*idNode).lastVisit - // if expired, remove the node and return nil - if time.Now().Sub(lastTime) > m.Expired { - m.delBean(tableName, id) - //m.clearIds(tableName) - return nil - } - m.idList.MoveToBack(el) - el.Value.(*idNode).lastVisit = time.Now() - } else { - el = m.idList.PushBack(newIdNode(tableName, id)) - m.idIndex[tableName][id] = el - } - return v - } else { - // store bean is not exist, then remove memory's index - m.delBean(tableName, id) - //m.clearIds(tableName) - return nil - } + m.mutex.Lock() + defer m.mutex.Unlock() + if _, ok := m.idIndex[tableName]; !ok { + m.idIndex[tableName] = make(map[interface{}]*list.Element) + } + tid := genId(tableName, id) + if v, err := m.store.Get(tid); err == nil { + if el, ok := m.idIndex[tableName][id]; ok { + lastTime := el.Value.(*idNode).lastVisit + // if expired, remove the node and return nil + if time.Now().Sub(lastTime) > m.Expired { + m.delBean(tableName, id) + //m.clearIds(tableName) + return nil + } + m.idList.MoveToBack(el) + el.Value.(*idNode).lastVisit = time.Now() + } else { + el = m.idList.PushBack(newIdNode(tableName, id)) + m.idIndex[tableName][id] = el + } + return v + } else { + // store bean is not exist, then remove memory's index + m.delBean(tableName, id) + //m.clearIds(tableName) + return nil + } } // Clear all sql-ids mapping on table tableName from cache func (m *LRUCacher) clearIds(tableName string) { - if tis, ok := m.sqlIndex[tableName]; ok { - for sql, v := range tis { - m.sqlList.Remove(v) - m.store.Del(sql) - } - } - m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + if tis, ok := m.sqlIndex[tableName]; ok { + for sql, v := range tis { + m.sqlList.Remove(v) + m.store.Del(sql) + } + } + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) } func (m *LRUCacher) ClearIds(tableName string) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.clearIds(tableName) + m.mutex.Lock() + defer m.mutex.Unlock() + m.clearIds(tableName) } func (m *LRUCacher) clearBeans(tableName string) { - if tis, ok := m.idIndex[tableName]; ok { - for id, v := range tis { - m.idList.Remove(v) - tid := genId(tableName, id.(int64)) - m.store.Del(tid) - } - } - m.idIndex[tableName] = make(map[interface{}]*list.Element) + if tis, ok := m.idIndex[tableName]; ok { + for id, v := range tis { + m.idList.Remove(v) + tid := genId(tableName, id.(int64)) + m.store.Del(tid) + } + } + m.idIndex[tableName] = make(map[interface{}]*list.Element) } func (m *LRUCacher) ClearBeans(tableName string) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.clearBeans(tableName) + m.mutex.Lock() + defer m.mutex.Unlock() + m.clearBeans(tableName) } func (m *LRUCacher) PutIds(tableName, sql string, ids interface{}) { - m.mutex.Lock() - defer m.mutex.Unlock() - if _, ok := m.sqlIndex[tableName]; !ok { - m.sqlIndex[tableName] = make(map[interface{}]*list.Element) - } - if el, ok := m.sqlIndex[tableName][sql]; !ok { - el = m.sqlList.PushBack(newSqlNode(tableName, sql)) - m.sqlIndex[tableName][sql] = el - } else { - el.Value.(*sqlNode).lastVisit = time.Now() - } - m.store.Put(sql, ids) - if m.sqlList.Len() > m.Max { - e := m.sqlList.Front() - node := e.Value.(*sqlNode) - m.delIds(node.tbName, node.sql) - } + m.mutex.Lock() + defer m.mutex.Unlock() + if _, ok := m.sqlIndex[tableName]; !ok { + m.sqlIndex[tableName] = make(map[interface{}]*list.Element) + } + if el, ok := m.sqlIndex[tableName][sql]; !ok { + el = m.sqlList.PushBack(newSqlNode(tableName, sql)) + m.sqlIndex[tableName][sql] = el + } else { + el.Value.(*sqlNode).lastVisit = time.Now() + } + m.store.Put(sql, ids) + if m.sqlList.Len() > m.Max { + e := m.sqlList.Front() + node := e.Value.(*sqlNode) + m.delIds(node.tbName, node.sql) + } } func (m *LRUCacher) PutBean(tableName string, id int64, obj interface{}) { - m.mutex.Lock() - defer m.mutex.Unlock() - var el *list.Element - var ok bool + m.mutex.Lock() + defer m.mutex.Unlock() + var el *list.Element + var ok bool - if el, ok = m.idIndex[tableName][id]; !ok { - el = m.idList.PushBack(newIdNode(tableName, id)) - m.idIndex[tableName][id] = el - } else { - el.Value.(*idNode).lastVisit = time.Now() - } + if el, ok = m.idIndex[tableName][id]; !ok { + el = m.idList.PushBack(newIdNode(tableName, id)) + m.idIndex[tableName][id] = el + } else { + el.Value.(*idNode).lastVisit = time.Now() + } - m.store.Put(genId(tableName, id), obj) - if m.idList.Len() > m.Max { - e := m.idList.Front() - node := e.Value.(*idNode) - m.delBean(node.tbName, node.id) - } + m.store.Put(genId(tableName, id), obj) + if m.idList.Len() > m.Max { + e := m.idList.Front() + node := e.Value.(*idNode) + m.delBean(node.tbName, node.id) + } } func (m *LRUCacher) delIds(tableName, sql string) { - if _, ok := m.sqlIndex[tableName]; ok { - if el, ok := m.sqlIndex[tableName][sql]; ok { - delete(m.sqlIndex[tableName], sql) - m.sqlList.Remove(el) - } - } - m.store.Del(sql) + if _, ok := m.sqlIndex[tableName]; ok { + if el, ok := m.sqlIndex[tableName][sql]; ok { + delete(m.sqlIndex[tableName], sql) + m.sqlList.Remove(el) + } + } + m.store.Del(sql) } func (m *LRUCacher) DelIds(tableName, sql string) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.delIds(tableName, sql) + m.mutex.Lock() + defer m.mutex.Unlock() + m.delIds(tableName, sql) } func (m *LRUCacher) delBean(tableName string, id int64) { - tid := genId(tableName, id) - if el, ok := m.idIndex[tableName][id]; ok { - delete(m.idIndex[tableName], id) - m.idList.Remove(el) - m.clearIds(tableName) - } - m.store.Del(tid) + tid := genId(tableName, id) + if el, ok := m.idIndex[tableName][id]; ok { + delete(m.idIndex[tableName], id) + m.idList.Remove(el) + m.clearIds(tableName) + } + m.store.Del(tid) } func (m *LRUCacher) DelBean(tableName string, id int64) { - m.mutex.Lock() - defer m.mutex.Unlock() - m.delBean(tableName, id) + m.mutex.Lock() + defer m.mutex.Unlock() + m.delBean(tableName, id) } func encodeIds(ids []int64) (s string) { - s = "[" - for _, id := range ids { - s += fmt.Sprintf("%v,", id) - } - s = s[:len(s)-1] + "]" - return + s = "[" + for _, id := range ids { + s += fmt.Sprintf("%v,", id) + } + s = s[:len(s)-1] + "]" + return } func decodeIds(s string) []int64 { - res := make([]int64, 0) - if len(s) >= 2 { - ss := strings.Split(s[1:len(s)-1], ",") - for _, s := range ss { - i, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return res - } - res = append(res, i) - } - } - return res + res := make([]int64, 0) + if len(s) >= 2 { + ss := strings.Split(s[1:len(s)-1], ",") + for _, s := range ss { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return res + } + res = append(res, i) + } + } + return res } func getCacheSql(m Cacher, tableName, sql string, args interface{}) ([]int64, error) { - bytes := m.GetIds(tableName, genSqlKey(sql, args)) - if bytes == nil { - return nil, errors.New("Not Exist") - } - objs := decodeIds(bytes.(string)) - return objs, nil + bytes := m.GetIds(tableName, genSqlKey(sql, args)) + if bytes == nil { + return nil, errors.New("Not Exist") + } + objs := decodeIds(bytes.(string)) + return objs, nil } func putCacheSql(m Cacher, ids []int64, tableName, sql string, args interface{}) error { - bytes := encodeIds(ids) - m.PutIds(tableName, genSqlKey(sql, args), bytes) - return nil + bytes := encodeIds(ids) + m.PutIds(tableName, genSqlKey(sql, args), bytes) + return nil } func genSqlKey(sql string, args interface{}) string { - return fmt.Sprintf("%v-%v", sql, args) + return fmt.Sprintf("%v-%v", sql, args) } func genId(prefix string, id int64) string { - return fmt.Sprintf("%v-%v", prefix, id) + return fmt.Sprintf("%v-%v", prefix, id) } diff --git a/engine.go b/engine.go index 33ddaf66..56a286cb 100644 --- a/engine.go +++ b/engine.go @@ -1,43 +1,43 @@ package xorm import ( - "bufio" - "bytes" - "database/sql" - "errors" - "fmt" - "io" - "os" - "reflect" - "strconv" - "strings" - "sync" + "bufio" + "bytes" + "database/sql" + "errors" + "fmt" + "io" + "os" + "reflect" + "strconv" + "strings" + "sync" ) const ( - POSTGRES = "postgres" - SQLITE = "sqlite3" - MYSQL = "mysql" - MYMYSQL = "mymysql" + POSTGRES = "postgres" + SQLITE = "sqlite3" + MYSQL = "mysql" + MYMYSQL = "mymysql" ) // a dialect is a driver's wrapper type dialect interface { - Init(DriverName, DataSourceName string) error - SqlType(t *Column) string - SupportInsertMany() bool - QuoteStr() string - AutoIncrStr() string - SupportEngine() bool - SupportCharset() bool - IndexOnTable() bool - IndexCheckSql(tableName, idxName string) (string, []interface{}) - TableCheckSql(tableName string) (string, []interface{}) - ColumnCheckSql(tableName, colName string) (string, []interface{}) + Init(DriverName, DataSourceName string) error + SqlType(t *Column) string + SupportInsertMany() bool + QuoteStr() string + AutoIncrStr() string + SupportEngine() bool + SupportCharset() bool + IndexOnTable() bool + IndexCheckSql(tableName, idxName string) (string, []interface{}) + TableCheckSql(tableName string) (string, []interface{}) + ColumnCheckSql(tableName, colName string) (string, []interface{}) - GetColumns(tableName string) ([]string, map[string]*Column, error) - GetTables() ([]*Table, error) - GetIndexes(tableName string) (map[string]*Index, error) + GetColumns(tableName string) ([]string, map[string]*Column, error) + GetTables() ([]*Table, error) + GetIndexes(tableName string) (map[string]*Index, error) } type PK []interface{} @@ -45,36 +45,36 @@ type PK []interface{} // Engine is the major struct of xorm, it means a database manager. // Commonly, an application only need one engine type Engine struct { - columnMapper IMapper - tableMapper IMapper - TagIdentifier string - DriverName string - DataSourceName string - dialect dialect - Tables map[reflect.Type]*Table - mutex *sync.Mutex - ShowSQL bool - ShowErr bool - ShowDebug bool - ShowWarn bool - Pool IConnectPool - Filters []Filter - Logger io.Writer - Cacher Cacher - UseCache bool + columnMapper IMapper + tableMapper IMapper + TagIdentifier string + DriverName string + DataSourceName string + dialect dialect + Tables map[reflect.Type]*Table + mutex *sync.Mutex + ShowSQL bool + ShowErr bool + ShowDebug bool + ShowWarn bool + Pool IConnectPool + Filters []Filter + Logger io.Writer + Cacher Cacher + UseCache bool } func (engine *Engine) SetMapper(mapper IMapper) { - engine.SetTableMapper(mapper) - engine.SetColumnMapper(mapper) + engine.SetTableMapper(mapper) + engine.SetColumnMapper(mapper) } func (engine *Engine) SetTableMapper(mapper IMapper) { - engine.tableMapper = mapper + engine.tableMapper = mapper } func (engine *Engine) SetColumnMapper(mapper IMapper) { - engine.columnMapper = mapper + engine.columnMapper = mapper } // If engine's database support batch insert records like @@ -82,122 +82,122 @@ func (engine *Engine) SetColumnMapper(mapper IMapper) { // When the return is ture, then engine.Insert(&users) will // generate batch sql and exeute. func (engine *Engine) SupportInsertMany() bool { - return engine.dialect.SupportInsertMany() + return engine.dialect.SupportInsertMany() } // Engine's database use which charactor as quote. // mysql, sqlite use ` and postgres use " func (engine *Engine) QuoteStr() string { - return engine.dialect.QuoteStr() + return engine.dialect.QuoteStr() } // Use QuoteStr quote the string sql func (engine *Engine) Quote(sql string) string { - return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() + return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() } // A simple wrapper to dialect's SqlType method func (engine *Engine) SqlType(c *Column) string { - return engine.dialect.SqlType(c) + return engine.dialect.SqlType(c) } // Database's autoincrement statement func (engine *Engine) AutoIncrStr() string { - return engine.dialect.AutoIncrStr() + return engine.dialect.AutoIncrStr() } // Set engine's pool, the pool default is Go's standard library's connection pool. func (engine *Engine) SetPool(pool IConnectPool) error { - engine.Pool = pool - return engine.Pool.Init(engine) + engine.Pool = pool + return engine.Pool.Init(engine) } // SetMaxConns is only available for go 1.2+ func (engine *Engine) SetMaxConns(conns int) { - engine.Pool.SetMaxConns(conns) + engine.Pool.SetMaxConns(conns) } // SetMaxIdleConns func (engine *Engine) SetMaxIdleConns(conns int) { - engine.Pool.SetMaxIdleConns(conns) + engine.Pool.SetMaxIdleConns(conns) } // SetDefaltCacher set the default cacher. Xorm's default not enable cacher. func (engine *Engine) SetDefaultCacher(cacher Cacher) { - if cacher == nil { - engine.UseCache = false - } else { - engine.UseCache = true - engine.Cacher = cacher - } + if cacher == nil { + engine.UseCache = false + } else { + engine.UseCache = true + engine.Cacher = cacher + } } // If you has set default cacher, and you want temporilly stop use cache, // you can use NoCache() func (engine *Engine) NoCache() *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.NoCache() + session := engine.NewSession() + session.IsAutoClose = true + return session.NoCache() } // Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { - t := rType(bean) - engine.autoMapType(t) - engine.Tables[t].Cacher = cacher + t := rType(bean) + engine.autoMapType(t) + engine.Tables[t].Cacher = cacher } // OpenDB provides a interface to operate database directly. func (engine *Engine) OpenDB() (*sql.DB, error) { - return sql.Open(engine.DriverName, engine.DataSourceName) + return sql.Open(engine.DriverName, engine.DataSourceName) } // New a session func (engine *Engine) NewSession() *Session { - session := &Session{Engine: engine} - session.Init() - return session + session := &Session{Engine: engine} + session.Init() + return session } // Close the engine func (engine *Engine) Close() error { - return engine.Pool.Close(engine) + return engine.Pool.Close(engine) } // Ping tests if database is alive. func (engine *Engine) Ping() error { - session := engine.NewSession() - defer session.Close() - engine.LogSQL("PING DATABASE", engine.DriverName) - return session.Ping() + session := engine.NewSession() + defer session.Close() + engine.LogSQL("PING DATABASE", engine.DriverName) + return session.Ping() } // logging sql func (engine *Engine) LogSQL(contents ...interface{}) { - if engine.ShowSQL { - io.WriteString(engine.Logger, fmt.Sprintln(contents...)) - } + if engine.ShowSQL { + io.WriteString(engine.Logger, fmt.Sprintln(contents...)) + } } // logging error func (engine *Engine) LogError(contents ...interface{}) { - if engine.ShowErr { - io.WriteString(engine.Logger, fmt.Sprintln(contents...)) - } + if engine.ShowErr { + io.WriteString(engine.Logger, fmt.Sprintln(contents...)) + } } // logging debug func (engine *Engine) LogDebug(contents ...interface{}) { - if engine.ShowDebug { - io.WriteString(engine.Logger, fmt.Sprintln(contents...)) - } + if engine.ShowDebug { + io.WriteString(engine.Logger, fmt.Sprintln(contents...)) + } } // logging warn func (engine *Engine) LogWarn(contents ...interface{}) { - if engine.ShowWarn { - io.WriteString(engine.Logger, fmt.Sprintln(contents...)) - } + if engine.ShowWarn { + io.WriteString(engine.Logger, fmt.Sprintln(contents...)) + } } // Sql method let's you manualy write raw sql and operate @@ -208,117 +208,117 @@ func (engine *Engine) LogWarn(contents ...interface{}) { // This code will execute "select * from user" and set the records to users // func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Sql(querystring, args...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Sql(querystring, args...) } // Default if your struct has "created" or "updated" filed tag, the fields // will automatically be filled with current time when Insert or Update // invoked. Call NoAutoTime if you dont' want to fill automatically. func (engine *Engine) NoAutoTime() *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.NoAutoTime() + session := engine.NewSession() + session.IsAutoClose = true + return session.NoAutoTime() } // Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*Table, error) { - tables, err := engine.dialect.GetTables() - if err != nil { - return nil, err - } + tables, err := engine.dialect.GetTables() + if err != nil { + return nil, err + } - for _, table := range tables { - colSeq, cols, err := engine.dialect.GetColumns(table.Name) - if err != nil { - return nil, err - } - table.Columns = cols - table.ColumnsSeq = colSeq + for _, table := range tables { + colSeq, cols, err := engine.dialect.GetColumns(table.Name) + if err != nil { + return nil, err + } + table.Columns = cols + table.ColumnsSeq = colSeq - indexes, err := engine.dialect.GetIndexes(table.Name) - if err != nil { - return nil, err - } - table.Indexes = indexes + indexes, err := engine.dialect.GetIndexes(table.Name) + if err != nil { + return nil, err + } + table.Indexes = indexes - for _, index := range indexes { - for _, name := range index.Cols { - if col, ok := table.Columns[name]; ok { - col.Indexes[index.Name] = true - } else { - return nil, errors.New("Unkonwn col " + name + " in indexes") - } - } - } - } - return tables, nil + for _, index := range indexes { + for _, name := range index.Cols { + if col, ok := table.Columns[name]; ok { + col.Indexes[index.Name] = true + } else { + return nil, errors.New("Unkonwn col " + name + " in indexes") + } + } + } + } + return tables, nil } // use cascade or not func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Cascade(trueOrFalse...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Cascade(trueOrFalse...) } // Where method provide a condition query func (engine *Engine) Where(querystring string, args ...interface{}) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Where(querystring, args...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Where(querystring, args...) } // Id mehtod provoide a condition as (id) = ? func (engine *Engine) Id(id interface{}) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Id(id) + session := engine.NewSession() + session.IsAutoClose = true + return session.Id(id) } // Apply before Processor, affected bean is passed to closure arg func (engine *Engine) Before(closures func(interface{})) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Before(closures) + session := engine.NewSession() + session.IsAutoClose = true + return session.Before(closures) } // Apply after insert Processor, affected bean is passed to closure arg func (engine *Engine) After(closures func(interface{})) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.After(closures) + session := engine.NewSession() + session.IsAutoClose = true + return session.After(closures) } // set charset when create table, only support mysql now func (engine *Engine) Charset(charset string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Charset(charset) + session := engine.NewSession() + session.IsAutoClose = true + return session.Charset(charset) } // set store engine when create table, only support mysql now func (engine *Engine) StoreEngine(storeEngine string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.StoreEngine(storeEngine) + session := engine.NewSession() + session.IsAutoClose = true + return session.StoreEngine(storeEngine) } // use for distinct columns. Caution: when you are using cache, // distinct will not be cached because cache system need id, // but distinct will not provide id func (engine *Engine) Distinct(columns ...string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Distinct(columns...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Distinct(columns...) } // only use the paramters as select or update columns func (engine *Engine) Cols(columns ...string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Cols(columns...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Cols(columns...) } // Xorm automatically retrieve condition according struct, but @@ -327,45 +327,45 @@ func (engine *Engine) Cols(columns ...string) *Session { // If no paramters, it will use all the bool field of struct, or // it will use paramters's columns func (engine *Engine) UseBool(columns ...string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.UseBool(columns...) + session := engine.NewSession() + session.IsAutoClose = true + return session.UseBool(columns...) } // Only not use the paramters as select or update columns func (engine *Engine) Omit(columns ...string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Omit(columns...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Omit(columns...) } // This method will generate "column IN (?, ?)" func (engine *Engine) In(column string, args ...interface{}) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.In(column, args...) + session := engine.NewSession() + session.IsAutoClose = true + return session.In(column, args...) } // Temporarily change the Get, Find, Update's table func (engine *Engine) Table(tableNameOrBean interface{}) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Table(tableNameOrBean) + session := engine.NewSession() + session.IsAutoClose = true + return session.Table(tableNameOrBean) } // This method will generate "LIMIT start, limit" func (engine *Engine) Limit(limit int, start ...int) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Limit(limit, start...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Limit(limit, start...) } // Method Desc will generate "ORDER BY column1 DESC, column2 DESC" // This will func (engine *Engine) Desc(colNames ...string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Desc(colNames...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Desc(colNames...) } // Method Asc will generate "ORDER BY column1 DESC, column2 Asc" @@ -375,507 +375,507 @@ func (engine *Engine) Desc(colNames ...string) *Session { // // SELECT * FROM user ORDER BY name DESC, age ASC // func (engine *Engine) Asc(colNames ...string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Asc(colNames...) + session := engine.NewSession() + session.IsAutoClose = true + return session.Asc(colNames...) } // Method OrderBy will generate "ORDER BY order" func (engine *Engine) OrderBy(order string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.OrderBy(order) + session := engine.NewSession() + session.IsAutoClose = true + return session.OrderBy(order) } // The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (engine *Engine) Join(join_operator, tablename, condition string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Join(join_operator, tablename, condition) + session := engine.NewSession() + session.IsAutoClose = true + return session.Join(join_operator, tablename, condition) } // Generate Group By statement func (engine *Engine) GroupBy(keys string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.GroupBy(keys) + session := engine.NewSession() + session.IsAutoClose = true + return session.GroupBy(keys) } // Generate Having statement func (engine *Engine) Having(conditions string) *Session { - session := engine.NewSession() - session.IsAutoClose = true - return session.Having(conditions) + session := engine.NewSession() + session.IsAutoClose = true + return session.Having(conditions) } func (engine *Engine) autoMapType(t reflect.Type) *Table { - engine.mutex.Lock() - defer engine.mutex.Unlock() - table, ok := engine.Tables[t] - if !ok { - table = engine.mapType(t) - engine.Tables[t] = table - } - return table + engine.mutex.Lock() + defer engine.mutex.Unlock() + table, ok := engine.Tables[t] + if !ok { + table = engine.mapType(t) + engine.Tables[t] = table + } + return table } func (engine *Engine) autoMap(bean interface{}) *Table { - t := rType(bean) - return engine.autoMapType(t) + t := rType(bean) + return engine.autoMapType(t) } func (engine *Engine) newTable() *Table { - table := &Table{} - table.Indexes = make(map[string]*Index) - table.Columns = make(map[string]*Column) - table.ColumnsSeq = make([]string, 0) - table.Created = make(map[string]bool) - table.Cacher = engine.Cacher - return table + table := &Table{} + table.Indexes = make(map[string]*Index) + table.Columns = make(map[string]*Column) + table.ColumnsSeq = make([]string, 0) + table.Created = make(map[string]bool) + table.Cacher = engine.Cacher + return table } func (engine *Engine) mapType(t reflect.Type) *Table { - table := engine.newTable() - table.Name = engine.tableMapper.Obj2Table(t.Name()) - table.Type = t + table := engine.newTable() + table.Name = engine.tableMapper.Obj2Table(t.Name()) + table.Type = t - var idFieldColName string + var idFieldColName string - for i := 0; i < t.NumField(); i++ { - tag := t.Field(i).Tag - ormTagStr := tag.Get(engine.TagIdentifier) - var col *Column - fieldType := t.Field(i).Type + for i := 0; i < t.NumField(); i++ { + tag := t.Field(i).Tag + ormTagStr := tag.Get(engine.TagIdentifier) + var col *Column + fieldType := t.Field(i).Type - if ormTagStr != "" { - col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: TWOSIDES, Indexes: make(map[string]bool)} - tags := strings.Split(ormTagStr, " ") + if ormTagStr != "" { + col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, + IsAutoIncrement: false, MapType: TWOSIDES, Indexes: make(map[string]bool)} + tags := strings.Split(ormTagStr, " ") - if len(tags) > 0 { - if tags[0] == "-" { - continue - } - if (strings.ToUpper(tags[0]) == "EXTENDS") && - (fieldType.Kind() == reflect.Struct) { - parentTable := engine.mapType(fieldType) - for name, col := range parentTable.Columns { - col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) - table.Columns[name] = col - table.ColumnsSeq = append(table.ColumnsSeq, name) - } + if len(tags) > 0 { + if tags[0] == "-" { + continue + } + if (strings.ToUpper(tags[0]) == "EXTENDS") && + (fieldType.Kind() == reflect.Struct) { + parentTable := engine.mapType(fieldType) + for name, col := range parentTable.Columns { + col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) + table.Columns[name] = col + table.ColumnsSeq = append(table.ColumnsSeq, name) + } - table.PrimaryKey = parentTable.PrimaryKey - continue - } - var indexType int - var indexName string - for j, key := range tags { - k := strings.ToUpper(key) - switch { - case k == "<-": - col.MapType = ONLYFROMDB - case k == "->": - col.MapType = ONLYTODB - case k == "PK": - col.IsPrimaryKey = true - col.Nullable = false - case k == "NULL": - col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT") - case k == "AUTOINCR": - col.IsAutoIncrement = true - case k == "DEFAULT": - col.Default = tags[j+1] - case k == "CREATED": - col.IsCreated = true - case k == "VERSION": - col.IsVersion = true - col.Default = "1" - case k == "UPDATED": - col.IsUpdated = true - case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): - indexType = IndexType - indexName = k[len("INDEX")+1 : len(k)-1] - case k == "INDEX": - indexType = IndexType - case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): - indexName = k[len("UNIQUE")+1 : len(k)-1] - indexType = UniqueType - case k == "UNIQUE": - indexType = UniqueType - case k == "NOTNULL": - col.Nullable = false - case k == "NOT": - default: - if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") { - if key != col.Default { - col.Name = key[1 : len(key)-1] - } - } else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { - fs := strings.Split(k, "(") - if _, ok := sqlTypes[fs[0]]; !ok { - continue - } - col.SQLType = SQLType{fs[0], 0, 0} - fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") - if len(fs2) == 2 { - col.Length, _ = strconv.Atoi(fs2[0]) - col.Length2, _ = strconv.Atoi(fs2[1]) - } else if len(fs2) == 1 { - col.Length, _ = strconv.Atoi(fs2[0]) - } - } else { - if _, ok := sqlTypes[k]; ok { - col.SQLType = SQLType{k, 0, 0} - } else if key != col.Default { - col.Name = key - } - } - engine.SqlType(col) - } - } - if col.SQLType.Name == "" { - col.SQLType = Type2SQLType(fieldType) - } - if col.Length == 0 { - col.Length = col.SQLType.DefaultLength - } - if col.Length2 == 0 { - col.Length2 = col.SQLType.DefaultLength2 - } - if col.Name == "" { - col.Name = engine.columnMapper.Obj2Table(t.Field(i).Name) - } - if indexType == IndexType { - if indexName == "" { - indexName = col.Name - } - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = true - } else { - index := NewIndex(indexName, IndexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = true - } - } else if indexType == UniqueType { - if indexName == "" { - indexName = col.Name - } - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = true - } else { - index := NewIndex(indexName, UniqueType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = true - } - } - } - } else { - sqlType := Type2SQLType(fieldType) - col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, - sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, - TWOSIDES, false, false, false, false} - } - if col.IsAutoIncrement { - col.Nullable = false - } + table.PrimaryKey = parentTable.PrimaryKey + continue + } + var indexType int + var indexName string + for j, key := range tags { + k := strings.ToUpper(key) + switch { + case k == "<-": + col.MapType = ONLYFROMDB + case k == "->": + col.MapType = ONLYTODB + case k == "PK": + col.IsPrimaryKey = true + col.Nullable = false + case k == "NULL": + col.Nullable = (strings.ToUpper(tags[j-1]) != "NOT") + case k == "AUTOINCR": + col.IsAutoIncrement = true + case k == "DEFAULT": + col.Default = tags[j+1] + case k == "CREATED": + col.IsCreated = true + case k == "VERSION": + col.IsVersion = true + col.Default = "1" + case k == "UPDATED": + col.IsUpdated = true + case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): + indexType = IndexType + indexName = k[len("INDEX")+1 : len(k)-1] + case k == "INDEX": + indexType = IndexType + case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): + indexName = k[len("UNIQUE")+1 : len(k)-1] + indexType = UniqueType + case k == "UNIQUE": + indexType = UniqueType + case k == "NOTNULL": + col.Nullable = false + case k == "NOT": + default: + if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") { + if key != col.Default { + col.Name = key[1 : len(key)-1] + } + } else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { + fs := strings.Split(k, "(") + if _, ok := sqlTypes[fs[0]]; !ok { + continue + } + col.SQLType = SQLType{fs[0], 0, 0} + fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") + if len(fs2) == 2 { + col.Length, _ = strconv.Atoi(fs2[0]) + col.Length2, _ = strconv.Atoi(fs2[1]) + } else if len(fs2) == 1 { + col.Length, _ = strconv.Atoi(fs2[0]) + } + } else { + if _, ok := sqlTypes[k]; ok { + col.SQLType = SQLType{k, 0, 0} + } else if key != col.Default { + col.Name = key + } + } + engine.SqlType(col) + } + } + if col.SQLType.Name == "" { + col.SQLType = Type2SQLType(fieldType) + } + if col.Length == 0 { + col.Length = col.SQLType.DefaultLength + } + if col.Length2 == 0 { + col.Length2 = col.SQLType.DefaultLength2 + } + if col.Name == "" { + col.Name = engine.columnMapper.Obj2Table(t.Field(i).Name) + } + if indexType == IndexType { + if indexName == "" { + indexName = col.Name + } + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = true + } else { + index := NewIndex(indexName, IndexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = true + } + } else if indexType == UniqueType { + if indexName == "" { + indexName = col.Name + } + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = true + } else { + index := NewIndex(indexName, UniqueType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = true + } + } + } + } else { + sqlType := Type2SQLType(fieldType) + col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, + sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, + TWOSIDES, false, false, false, false} + } + if col.IsAutoIncrement { + col.Nullable = false + } - table.AddColumn(col) + table.AddColumn(col) - if col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id") { - idFieldColName = col.Name - } - } + if col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id") { + idFieldColName = col.Name + } + } - if idFieldColName != "" && table.PrimaryKey == "" { - col := table.Columns[idFieldColName] - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - table.PrimaryKey = col.Name - } + if idFieldColName != "" && table.PrimaryKey == "" { + col := table.Columns[idFieldColName] + col.IsPrimaryKey = true + col.IsAutoIncrement = true + col.Nullable = false + table.PrimaryKey = col.Name + } - return table + return table } // Map a struct to a table func (engine *Engine) mapping(beans ...interface{}) (e error) { - engine.mutex.Lock() - defer engine.mutex.Unlock() - for _, bean := range beans { - t := rType(bean) - engine.Tables[t] = engine.mapType(t) - } - return + engine.mutex.Lock() + defer engine.mutex.Unlock() + for _, bean := range beans { + t := rType(bean) + engine.Tables[t] = engine.mapType(t) + } + return } // If a table has any reocrd func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { - t := rType(bean) - if t.Kind() != reflect.Struct { - return false, errors.New("bean should be a struct or struct's point") - } - engine.autoMapType(t) - session := engine.NewSession() - defer session.Close() - rows, err := session.Count(bean) - return rows > 0, err + t := rType(bean) + if t.Kind() != reflect.Struct { + return false, errors.New("bean should be a struct or struct's point") + } + engine.autoMapType(t) + session := engine.NewSession() + defer session.Close() + rows, err := session.Count(bean) + return rows > 0, err } // If a table is exist func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { - t := rType(bean) - if t.Kind() != reflect.Struct { - return false, errors.New("bean should be a struct or struct's point") - } - table := engine.autoMapType(t) - session := engine.NewSession() - defer session.Close() - has, err := session.isTableExist(table.Name) - return has, err + t := rType(bean) + if t.Kind() != reflect.Struct { + return false, errors.New("bean should be a struct or struct's point") + } + table := engine.autoMapType(t) + session := engine.NewSession() + defer session.Close() + has, err := session.isTableExist(table.Name) + return has, err } // create indexes func (engine *Engine) CreateIndexes(bean interface{}) error { - session := engine.NewSession() - defer session.Close() - return session.CreateIndexes(bean) + session := engine.NewSession() + defer session.Close() + return session.CreateIndexes(bean) } // create uniques func (engine *Engine) CreateUniques(bean interface{}) error { - session := engine.NewSession() - defer session.Close() - return session.CreateUniques(bean) + session := engine.NewSession() + defer session.Close() + return session.CreateUniques(bean) } // If enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { - t := rType(bean) - if t.Kind() != reflect.Struct { - return errors.New("error params") - } - table := engine.autoMap(bean) - if table.Cacher != nil { - table.Cacher.ClearIds(table.Name) - table.Cacher.DelBean(table.Name, id) - } - return nil + t := rType(bean) + if t.Kind() != reflect.Struct { + return errors.New("error params") + } + table := engine.autoMap(bean) + if table.Cacher != nil { + table.Cacher.ClearIds(table.Name) + table.Cacher.DelBean(table.Name, id) + } + return nil } // If enabled cache, clear some tables' cache func (engine *Engine) ClearCache(beans ...interface{}) error { - for _, bean := range beans { - t := rType(bean) - if t.Kind() != reflect.Struct { - return errors.New("error params") - } - table := engine.autoMap(bean) - if table.Cacher != nil { - table.Cacher.ClearIds(table.Name) - table.Cacher.ClearBeans(table.Name) - } - } - return nil + for _, bean := range beans { + t := rType(bean) + if t.Kind() != reflect.Struct { + return errors.New("error params") + } + table := engine.autoMap(bean) + if table.Cacher != nil { + table.Cacher.ClearIds(table.Name) + table.Cacher.ClearBeans(table.Name) + } + } + return nil } // Sync the new struct changes to database, this method will automatically add // table, column, index, unique. but will not delete or change anything. // If you change some field, you should change the database manually. func (engine *Engine) Sync(beans ...interface{}) error { - for _, bean := range beans { - table := engine.autoMap(bean) + for _, bean := range beans { + table := engine.autoMap(bean) - s := engine.NewSession() - defer s.Close() - isExist, err := s.Table(bean).isTableExist(table.Name) - if err != nil { - return err - } - if !isExist { - err = engine.CreateTables(bean) - if err != nil { - return err - } - } - /*isEmpty, err := engine.IsEmptyTable(bean) - if err != nil { - return err - }*/ - var isEmpty bool = false - if isEmpty { - err = engine.DropTables(bean) - if err != nil { - return err - } - err = engine.CreateTables(bean) - if err != nil { - return err - } - } else { - for _, col := range table.Columns { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - isExist, err := session.isColumnExist(table.Name, col.Name) - if err != nil { - return err - } - if !isExist { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addColumn(col.Name) - if err != nil { - return err - } - } - } + s := engine.NewSession() + defer s.Close() + isExist, err := s.Table(bean).isTableExist(table.Name) + if err != nil { + return err + } + if !isExist { + err = engine.CreateTables(bean) + if err != nil { + return err + } + } + /*isEmpty, err := engine.IsEmptyTable(bean) + if err != nil { + return err + }*/ + var isEmpty bool = false + if isEmpty { + err = engine.DropTables(bean) + if err != nil { + return err + } + err = engine.CreateTables(bean) + if err != nil { + return err + } + } else { + for _, col := range table.Columns { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + isExist, err := session.isColumnExist(table.Name, col.Name) + if err != nil { + return err + } + if !isExist { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addColumn(col.Name) + if err != nil { + return err + } + } + } - for name, index := range table.Indexes { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - if index.Type == UniqueType { - //isExist, err := session.isIndexExist(table.Name, name, true) - isExist, err := session.isIndexExist2(table.Name, index.Cols, true) - if err != nil { - return err - } - if !isExist { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addUnique(table.Name, name) - if err != nil { - return err - } - } - } else if index.Type == IndexType { - isExist, err := session.isIndexExist2(table.Name, index.Cols, false) - if err != nil { - return err - } - if !isExist { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addIndex(table.Name, name) - if err != nil { - return err - } - } - } else { - return errors.New("unknow index type") - } - } - } - } - return nil + for name, index := range table.Indexes { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + if index.Type == UniqueType { + //isExist, err := session.isIndexExist(table.Name, name, true) + isExist, err := session.isIndexExist2(table.Name, index.Cols, true) + if err != nil { + return err + } + if !isExist { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addUnique(table.Name, name) + if err != nil { + return err + } + } + } else if index.Type == IndexType { + isExist, err := session.isIndexExist2(table.Name, index.Cols, false) + if err != nil { + return err + } + if !isExist { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addIndex(table.Name, name) + if err != nil { + return err + } + } + } else { + return errors.New("unknow index type") + } + } + } + } + return nil } func (engine *Engine) unMap(beans ...interface{}) (e error) { - engine.mutex.Lock() - defer engine.mutex.Unlock() - for _, bean := range beans { - t := rType(bean) - if _, ok := engine.Tables[t]; ok { - delete(engine.Tables, t) - } - } - return + engine.mutex.Lock() + defer engine.mutex.Unlock() + for _, bean := range beans { + t := rType(bean) + if _, ok := engine.Tables[t]; ok { + delete(engine.Tables, t) + } + } + return } // Drop all mapped table func (engine *Engine) dropAll() error { - session := engine.NewSession() - defer session.Close() + session := engine.NewSession() + defer session.Close() - err := session.Begin() - if err != nil { - return err - } - err = session.dropAll() - if err != nil { - session.Rollback() - return err - } - return session.Commit() + err := session.Begin() + if err != nil { + return err + } + err = session.dropAll() + if err != nil { + session.Rollback() + return err + } + return session.Commit() } // CreateTables create tabls according bean func (engine *Engine) CreateTables(beans ...interface{}) error { - session := engine.NewSession() - err := session.Begin() - defer session.Close() - if err != nil { - return err - } + session := engine.NewSession() + err := session.Begin() + defer session.Close() + if err != nil { + return err + } - for _, bean := range beans { - err = session.CreateTable(bean) - if err != nil { - session.Rollback() - return err - } - } - return session.Commit() + for _, bean := range beans { + err = session.CreateTable(bean) + if err != nil { + session.Rollback() + return err + } + } + return session.Commit() } func (engine *Engine) DropTables(beans ...interface{}) error { - session := engine.NewSession() - err := session.Begin() - defer session.Close() - if err != nil { - return err - } + session := engine.NewSession() + err := session.Begin() + defer session.Close() + if err != nil { + return err + } - for _, bean := range beans { - err = session.DropTable(bean) - if err != nil { - session.Rollback() - return err - } - } - return session.Commit() + for _, bean := range beans { + err = session.DropTable(bean) + if err != nil { + session.Rollback() + return err + } + } + return session.Commit() } func (engine *Engine) createAll() error { - session := engine.NewSession() - defer session.Close() - return session.createAll() + session := engine.NewSession() + defer session.Close() + return session.createAll() } // Exec raw sql func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { - session := engine.NewSession() - defer session.Close() - return session.Exec(sql, args...) + session := engine.NewSession() + defer session.Close() + return session.Exec(sql, args...) } // Exec a raw sql and return records as []map[string][]byte func (engine *Engine) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - session := engine.NewSession() - defer session.Close() - return session.Query(sql, paramStr...) + session := engine.NewSession() + defer session.Close() + return session.Query(sql, paramStr...) } // Insert one or more records func (engine *Engine) Insert(beans ...interface{}) (int64, error) { - session := engine.NewSession() - defer session.Close() - return session.Insert(beans...) + session := engine.NewSession() + defer session.Close() + return session.Insert(beans...) } // Insert only one record func (engine *Engine) InsertOne(bean interface{}) (int64, error) { - session := engine.NewSession() - defer session.Close() - return session.InsertOne(bean) + session := engine.NewSession() + defer session.Close() + return session.InsertOne(bean) } // Update records, bean's non-empty fields are updated contents, @@ -885,94 +885,94 @@ func (engine *Engine) InsertOne(bean interface{}) (int64, error) { // You should call UseBool if you have bool to use. // 2.float32 & float64 may be not inexact as conditions func (engine *Engine) Update(bean interface{}, condiBeans ...interface{}) (int64, error) { - session := engine.NewSession() - defer session.Close() - return session.Update(bean, condiBeans...) + session := engine.NewSession() + defer session.Close() + return session.Update(bean, condiBeans...) } // Delete records, bean's non-empty fields are conditions func (engine *Engine) Delete(bean interface{}) (int64, error) { - session := engine.NewSession() - defer session.Close() - return session.Delete(bean) + session := engine.NewSession() + defer session.Close() + return session.Delete(bean) } // Get retrieve one record from table, bean's non-empty fields // are conditions func (engine *Engine) Get(bean interface{}) (bool, error) { - session := engine.NewSession() - defer session.Close() - return session.Get(bean) + session := engine.NewSession() + defer session.Close() + return session.Get(bean) } // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { - session := engine.NewSession() - defer session.Close() - return session.Find(beans, condiBeans...) + session := engine.NewSession() + defer session.Close() + return session.Find(beans, condiBeans...) } // Iterate record by record handle records from table, bean's non-empty fields // are conditions. func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { - session := engine.NewSession() - defer session.Close() - return session.Iterate(bean, fun) + session := engine.NewSession() + defer session.Close() + return session.Iterate(bean, fun) } // Count counts the records. bean's non-empty fields // are conditions. func (engine *Engine) Count(bean interface{}) (int64, error) { - session := engine.NewSession() - defer session.Close() - return session.Count(bean) + session := engine.NewSession() + defer session.Close() + return session.Count(bean) } // Import SQL DDL file func (engine *Engine) Import(ddlPath string) ([]sql.Result, error) { - file, err := os.Open(ddlPath) - if err != nil { - return nil, err - } - defer file.Close() + file, err := os.Open(ddlPath) + if err != nil { + return nil, err + } + defer file.Close() - var results []sql.Result - var lastError error - scanner := bufio.NewScanner(file) + var results []sql.Result + var lastError error + scanner := bufio.NewScanner(file) - semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := bytes.IndexByte(data, ';'); i >= 0 { - return i + 1, data[0:i], nil - } - // If we're at EOF, we have a final, non-terminated line. Return it. - if atEOF { - return len(data), data, nil - } - // Request more data. - return 0, nil, nil - } + semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, ';'); i >= 0 { + return i + 1, data[0:i], nil + } + // If we're at EOF, we have a final, non-terminated line. Return it. + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + } - scanner.Split(semiColSpliter) + scanner.Split(semiColSpliter) - session := engine.NewSession() - session.IsAutoClose = false - for scanner.Scan() { - query := scanner.Text() - query = strings.Trim(query, " \t") - if len(query) > 0 { - result, err := session.Exec(query) - results = append(results, result) - if err != nil { - lastError = err - } - } - } - session.Close() - return results, lastError + session := engine.NewSession() + session.IsAutoClose = false + for scanner.Scan() { + query := scanner.Text() + query = strings.Trim(query, " \t") + if len(query) > 0 { + result, err := session.Exec(query) + results = append(results, result) + if err != nil { + lastError = err + } + } + } + session.Close() + return results, lastError } diff --git a/error.go b/error.go index ac3ed7de..c868173f 100644 --- a/error.go +++ b/error.go @@ -1,15 +1,15 @@ package xorm import ( - "errors" + "errors" ) var ( - ErrParamsType error = errors.New("Params type error") - ErrTableNotFound error = errors.New("Not found table") - ErrUnSupportedType error = errors.New("Unsupported type error") - ErrNotExist error = errors.New("Not exist error") - ErrCacheFailed error = errors.New("Cache failed") - ErrNeedDeletedCond error = errors.New("Delete need at least one condition") - ErrNotImplemented error = errors.New("Not implemented.") + ErrParamsType error = errors.New("Params type error") + ErrTableNotFound error = errors.New("Not found table") + ErrUnSupportedType error = errors.New("Unsupported type error") + ErrNotExist error = errors.New("Not exist error") + ErrCacheFailed error = errors.New("Cache failed") + ErrNeedDeletedCond error = errors.New("Delete need at least one condition") + ErrNotImplemented error = errors.New("Not implemented.") ) diff --git a/filter.go b/filter.go index 3d576800..5fff4c0d 100644 --- a/filter.go +++ b/filter.go @@ -1,13 +1,13 @@ package xorm import ( - "fmt" - "strings" + "fmt" + "strings" ) // Filter is an interface to filter SQL type Filter interface { - Do(sql string, session *Session) string + Do(sql string, session *Session) string } // PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ... @@ -15,16 +15,16 @@ type PgSeqFilter struct { } func (s *PgSeqFilter) Do(sql string, session *Session) string { - segs := strings.Split(sql, "?") - size := len(segs) - res := "" - for i, c := range segs { - if i < size-1 { - res += c + fmt.Sprintf("$%v", i+1) - } - } - res += segs[size-1] - return res + segs := strings.Split(sql, "?") + size := len(segs) + res := "" + for i, c := range segs { + if i < size-1 { + res += c + fmt.Sprintf("$%v", i+1) + } + } + res += segs[size-1] + return res } // QuoteFilter filter SQL replace ` to database's own quote character @@ -32,7 +32,7 @@ type QuoteFilter struct { } func (s *QuoteFilter) Do(sql string, session *Session) string { - return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1) + return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1) } // IdFilter filter SQL replace (id) to primary key column name @@ -40,10 +40,10 @@ type IdFilter struct { } func (i *IdFilter) Do(sql string, session *Session) string { - if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { - sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - } - return sql + if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { + sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) + sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) + return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) + } + return sql } diff --git a/helpers.go b/helpers.go index 6b73af9e..307353c2 100644 --- a/helpers.go +++ b/helpers.go @@ -1,63 +1,63 @@ package xorm import ( - "reflect" - "strings" + "reflect" + "strings" ) func indexNoCase(s, sep string) int { - return strings.Index(strings.ToLower(s), strings.ToLower(sep)) + return strings.Index(strings.ToLower(s), strings.ToLower(sep)) } func splitNoCase(s, sep string) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.Split(s, s[idx:idx+len(sep)]) + idx := indexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.Split(s, s[idx:idx+len(sep)]) } func splitNNoCase(s, sep string, n int) []string { - idx := indexNoCase(s, sep) - if idx < 0 { - return []string{s} - } - return strings.SplitN(s, s[idx:idx+len(sep)], n) + idx := indexNoCase(s, sep) + if idx < 0 { + return []string{s} + } + return strings.SplitN(s, s[idx:idx+len(sep)], n) } func makeArray(elem string, count int) []string { - res := make([]string, count) - for i := 0; i < count; i++ { - res[i] = elem - } - return res + res := make([]string, count) + for i := 0; i < count; i++ { + res[i] = elem + } + return res } func rType(bean interface{}) reflect.Type { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - return reflect.TypeOf(sliceValue.Interface()) + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + return reflect.TypeOf(sliceValue.Interface()) } func structName(v reflect.Type) string { - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - return v.Name() + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Name() } func sliceEq(left, right []string) bool { - for _, l := range left { - var find bool - for _, r := range right { - if l == r { - find = true - break - } - } - if !find { - return false - } - } + for _, l := range left { + var find bool + for _, r := range right { + if l == r { + find = true + break + } + } + if !find { + return false + } + } - return true + return true } diff --git a/mapper.go b/mapper.go index 078f3b09..2e9c220a 100644 --- a/mapper.go +++ b/mapper.go @@ -1,13 +1,13 @@ package xorm import ( - "strings" + "strings" ) // name translation between struct, fields names and table, column names type IMapper interface { - Obj2Table(string) string - Table2Obj(string) string + Obj2Table(string) string + Table2Obj(string) string } // SameMapper implements IMapper and provides same name between struct and @@ -16,11 +16,11 @@ type SameMapper struct { } func (m SameMapper) Obj2Table(o string) string { - return o + return o } func (m SameMapper) Table2Obj(t string) string { - return t + return t } // SnakeMapper implements IMapper and provides name transaltion between @@ -29,18 +29,18 @@ type SnakeMapper struct { } func snakeCasedName(name string) string { - newstr := make([]rune, 0) - for idx, chr := range name { - if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { - if idx > 0 { - newstr = append(newstr, '_') - } - chr -= ('A' - 'a') - } - newstr = append(newstr, chr) - } + newstr := make([]rune, 0) + for idx, chr := range name { + if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { + if idx > 0 { + newstr = append(newstr, '_') + } + chr -= ('A' - 'a') + } + newstr = append(newstr, chr) + } - return string(newstr) + return string(newstr) } /*func pascal2Sql(s string) (d string) { @@ -63,69 +63,69 @@ func snakeCasedName(name string) string { }*/ func (mapper SnakeMapper) Obj2Table(name string) string { - return snakeCasedName(name) + return snakeCasedName(name) } func titleCasedName(name string) string { - newstr := make([]rune, 0) - upNextChar := true + newstr := make([]rune, 0) + upNextChar := true - name = strings.ToLower(name) + name = strings.ToLower(name) - for _, chr := range name { - switch { - case upNextChar: - upNextChar = false - if 'a' <= chr && chr <= 'z' { - chr -= ('a' - 'A') - } - case chr == '_': - upNextChar = true - continue - } + for _, chr := range name { + switch { + case upNextChar: + upNextChar = false + if 'a' <= chr && chr <= 'z' { + chr -= ('a' - 'A') + } + case chr == '_': + upNextChar = true + continue + } - newstr = append(newstr, chr) - } + newstr = append(newstr, chr) + } - return string(newstr) + return string(newstr) } func (mapper SnakeMapper) Table2Obj(name string) string { - return titleCasedName(name) + return titleCasedName(name) } // provide prefix table name support type PrefixMapper struct { - Mapper IMapper - Prefix string + Mapper IMapper + Prefix string } func (mapper PrefixMapper) Obj2Table(name string) string { - return mapper.Prefix + mapper.Mapper.Obj2Table(name) + return mapper.Prefix + mapper.Mapper.Obj2Table(name) } func (mapper PrefixMapper) Table2Obj(name string) string { - return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) + return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) } func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper { - return PrefixMapper{mapper, prefix} + return PrefixMapper{mapper, prefix} } // provide suffix table name support type SuffixMapper struct { - Mapper IMapper - Suffix string + Mapper IMapper + Suffix string } func (mapper SuffixMapper) Obj2Table(name string) string { - return mapper.Suffix + mapper.Mapper.Obj2Table(name) + return mapper.Suffix + mapper.Mapper.Obj2Table(name) } func (mapper SuffixMapper) Table2Obj(name string) string { - return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):]) + return mapper.Mapper.Table2Obj(name[len(mapper.Suffix):]) } func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper { - return SuffixMapper{mapper, suffix} + return SuffixMapper{mapper, suffix} } diff --git a/mymysql.go b/mymysql.go index c664be0d..8101d4c0 100644 --- a/mymysql.go +++ b/mymysql.go @@ -1,66 +1,67 @@ package xorm import ( - "errors" - "strings" - "time" + "errors" + "strings" + "time" ) type mymysql struct { - mysql - proto string - raddr string - laddr string - timeout time.Duration - db string - user string - passwd string + mysql +} + +type mymysqlParser struct { +} + +func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) { + db := &uri{dbType: MYSQL} + + pd := strings.SplitN(dataSourceName, "*", 2) + if len(pd) == 2 { + // Parse protocol part of URI + p := strings.SplitN(pd[0], ":", 2) + if len(p) != 2 { + return nil, errors.New("Wrong protocol part of URI") + } + db.proto = p[0] + options := strings.Split(p[1], ",") + db.raddr = options[0] + for _, o := range options[1:] { + kv := strings.SplitN(o, "=", 2) + var k, v string + if len(kv) == 2 { + k, v = kv[0], kv[1] + } else { + k, v = o, "true" + } + switch k { + case "laddr": + db.laddr = v + case "timeout": + to, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + db.timeout = to + default: + return nil, errors.New("Unknown option: " + k) + } + } + // Remove protocol part + pd = pd[1:] + } + // Parse database part of URI + dup := strings.SplitN(pd[0], "/", 3) + if len(dup) != 3 { + return nil, errors.New("Wrong database part of URI") + } + db.dbName = dup[0] + db.user = dup[1] + db.passwd = dup[2] + + return db, nil } func (db *mymysql) Init(drivername, uri string) error { - db.mysql.base.init(drivername, uri) - pd := strings.SplitN(uri, "*", 2) - if len(pd) == 2 { - // Parse protocol part of URI - p := strings.SplitN(pd[0], ":", 2) - if len(p) != 2 { - return errors.New("Wrong protocol part of URI") - } - db.proto = p[0] - options := strings.Split(p[1], ",") - db.raddr = options[0] - for _, o := range options[1:] { - kv := strings.SplitN(o, "=", 2) - var k, v string - if len(kv) == 2 { - k, v = kv[0], kv[1] - } else { - k, v = o, "true" - } - switch k { - case "laddr": - db.laddr = v - case "timeout": - to, err := time.ParseDuration(v) - if err != nil { - return err - } - db.timeout = to - default: - return errors.New("Unknown option: " + k) - } - } - // Remove protocol part - pd = pd[1:] - } - // Parse database part of URI - dup := strings.SplitN(pd[0], "/", 3) - if len(dup) != 3 { - return errors.New("Wrong database part of URI") - } - db.dbname = dup[0] - db.user = dup[1] - db.passwd = dup[2] - - return nil + return db.mysql.base.init(&mymysqlParser{}, drivername, uri) } diff --git a/mysql.go b/mysql.go index 17e6603c..bde0186a 100644 --- a/mysql.go +++ b/mysql.go @@ -1,311 +1,323 @@ package xorm import ( - "crypto/tls" - "database/sql" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - "time" + "crypto/tls" + "database/sql" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" ) -type base struct { - drivername string - dataSourceName string +type uri struct { + dbType string + proto string + host string + port string + dbName string + user string + passwd string + charset string + laddr string + raddr string + timeout time.Duration } -func (b *base) init(drivername, dataSourceName string) { - b.drivername, b.dataSourceName = drivername, dataSourceName +type parser interface { + parse(driverName, dataSourceName string) (*uri, error) +} + +type mysqlParser struct { +} + +func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) { + //cfg.params = make(map[string]string) + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + //tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + uri := &uri{dbType: MYSQL} + + for i, match := range matches { + switch names[i] { + case "dbname": + uri.dbName = match + } + } + return uri, nil +} + +type base struct { + parser parser + driverName string + dataSourceName string + *uri +} + +func (b *base) init(parser parser, drivername, dataSourceName string) (err error) { + b.parser = parser + b.driverName, b.dataSourceName = drivername, dataSourceName + b.uri, err = b.parser.parse(b.driverName, b.dataSourceName) + return } type mysql struct { - base - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - timeout time.Duration - tls *tls.Config - allowAllFiles bool - allowOldPasswords bool - clientFoundRows bool -} - -/*func readBool(input string) (value bool, valid bool) { - switch input { - case "1", "true", "TRUE", "True": - return true, true - case "0", "false", "FALSE", "False": - return false, true - } - - // Not a valid bool value - return -}*/ - -func (cfg *mysql) parseDSN(dsn string) (err error) { - //cfg.params = make(map[string]string) - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dsn) - //tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - for i, match := range matches { - switch names[i] { - case "dbname": - cfg.dbname = match - } - } - return + base + net string + addr string + params map[string]string + loc *time.Location + timeout time.Duration + tls *tls.Config + allowAllFiles bool + allowOldPasswords bool + clientFoundRows bool } func (db *mysql) Init(drivername, uri string) error { - db.base.init(drivername, uri) - return db.parseDSN(uri) + return db.base.init(&mysqlParser{}, drivername, uri) } func (db *mysql) SqlType(c *Column) string { - var res string - switch t := c.SQLType.Name; t { - case Bool: - res = TinyInt - case Serial: - c.IsAutoIncrement = true - c.IsPrimaryKey = true - c.Nullable = false - res = Int - case BigSerial: - c.IsAutoIncrement = true - c.IsPrimaryKey = true - c.Nullable = false - res = BigInt - case Bytea: - res = Blob - case TimeStampz: - res = Char - c.Length = 64 - default: - res = t - } + var res string + switch t := c.SQLType.Name; t { + case Bool: + res = TinyInt + case Serial: + c.IsAutoIncrement = true + c.IsPrimaryKey = true + c.Nullable = false + res = Int + case BigSerial: + c.IsAutoIncrement = true + c.IsPrimaryKey = true + c.Nullable = false + res = BigInt + case Bytea: + res = Blob + case TimeStampz: + res = Char + c.Length = 64 + default: + res = t + } - var hasLen1 bool = (c.Length > 0) - var hasLen2 bool = (c.Length2 > 0) - if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" - } else if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" - } - return res + var hasLen1 bool = (c.Length > 0) + var hasLen2 bool = (c.Length2 > 0) + if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } else if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } + return res } func (db *mysql) SupportInsertMany() bool { - return true + return true } func (db *mysql) QuoteStr() string { - return "`" + return "`" } func (db *mysql) SupportEngine() bool { - return true + return true } func (db *mysql) AutoIncrStr() string { - return "AUTO_INCREMENT" + return "AUTO_INCREMENT" } func (db *mysql) SupportCharset() bool { - return true + return true } func (db *mysql) IndexOnTable() bool { - return true + return true } func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{db.dbname, tableName, idxName} - sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" - sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" - return sql, args + args := []interface{}{db.dbName, tableName, idxName} + sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" + sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" + return sql, args } func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{db.dbname, tableName, colName} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - return sql, args + args := []interface{}{db.dbName, tableName, colName} + sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" + return sql, args } func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{db.dbname, tableName} - sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return sql, args + args := []interface{}{db.dbName, tableName} + sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" + return sql, args } func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{db.dbname, tableName} - s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + - " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, record := range res { - col := new(Column) - col.Indexes = make(map[string]bool) - for name, content := range record { - switch name { - case "COLUMN_NAME": - col.Name = strings.Trim(string(content), "` ") - case "IS_NULLABLE": - if "YES" == string(content) { - col.Nullable = true - } - case "COLUMN_DEFAULT": - // add '' - col.Default = string(content) - case "COLUMN_TYPE": - cts := strings.Split(string(content), "(") - var len1, len2 int - if len(cts) == 2 { - idx := strings.Index(cts[1], ")") - lens := strings.Split(cts[1][0:idx], ",") - len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) - if err != nil { - return nil, nil, err - } - if len(lens) == 2 { - len2, err = strconv.Atoi(lens[1]) - if err != nil { - return nil, nil, err - } - } - } - colName := cts[0] - colType := strings.ToUpper(colName) - col.Length = len1 - col.Length2 = len2 - if _, ok := sqlTypes[colType]; ok { - col.SQLType = SQLType{colType, len1, len2} - } else { - return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) - } - case "COLUMN_KEY": - key := string(content) - if key == "PRI" { - col.IsPrimaryKey = true - } - if key == "UNI" { - //col.is - } - case "EXTRA": - extra := string(content) - if extra == "auto_increment" { - col.IsAutoIncrement = true - } - } - } - if col.SQLType.IsText() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } - return colSeq, cols, nil + args := []interface{}{db.dbName, tableName} + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, record := range res { + col := new(Column) + col.Indexes = make(map[string]bool) + for name, content := range record { + switch name { + case "COLUMN_NAME": + col.Name = strings.Trim(string(content), "` ") + case "IS_NULLABLE": + if "YES" == string(content) { + col.Nullable = true + } + case "COLUMN_DEFAULT": + // add '' + col.Default = string(content) + case "COLUMN_TYPE": + cts := strings.Split(string(content), "(") + var len1, len2 int + if len(cts) == 2 { + idx := strings.Index(cts[1], ")") + lens := strings.Split(cts[1][0:idx], ",") + len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) + if err != nil { + return nil, nil, err + } + if len(lens) == 2 { + len2, err = strconv.Atoi(lens[1]) + if err != nil { + return nil, nil, err + } + } + } + colName := cts[0] + colType := strings.ToUpper(colName) + col.Length = len1 + col.Length2 = len2 + if _, ok := sqlTypes[colType]; ok { + col.SQLType = SQLType{colType, len1, len2} + } else { + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) + } + case "COLUMN_KEY": + key := string(content) + if key == "PRI" { + col.IsPrimaryKey = true + } + if key == "UNI" { + //col.is + } + case "EXTRA": + extra := string(content) + if extra == "auto_increment" { + col.IsAutoIncrement = true + } + } + } + if col.SQLType.IsText() { + if col.Default != "" { + col.Default = "'" + col.Default + "'" + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + return colSeq, cols, nil } func (db *mysql) GetTables() ([]*Table, error) { - args := []interface{}{db.dbname} - s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{db.dbName} + s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "TABLE_NAME": - table.Name = strings.Trim(string(content), "` ") - case "ENGINE": - } - } - tables = append(tables, table) - } - return tables, nil + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "TABLE_NAME": + table.Name = strings.Trim(string(content), "` ") + case "ENGINE": + } + } + tables = append(tables, table) + } + return tables, nil } func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{db.dbname, tableName} - s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{db.dbName, tableName} + s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - indexes := make(map[string]*Index, 0) - for _, record := range res { - var indexType int - var indexName, colName string - for name, content := range record { - switch name { - case "NON_UNIQUE": - if "YES" == string(content) || string(content) == "1" { - indexType = IndexType - } else { - indexType = UniqueType - } - case "INDEX_NAME": - indexName = string(content) - case "COLUMN_NAME": - colName = strings.Trim(string(content), "` ") - } - } - if indexName == "PRIMARY" { - continue - } - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - indexName = indexName[5+len(tableName) : len(indexName)] - } + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName, colName string + for name, content := range record { + switch name { + case "NON_UNIQUE": + if "YES" == string(content) || string(content) == "1" { + indexType = IndexType + } else { + indexType = UniqueType + } + case "INDEX_NAME": + indexName = string(content) + case "COLUMN_NAME": + colName = strings.Trim(string(content), "` ") + } + } + if indexName == "PRIMARY" { + continue + } + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + indexName = indexName[5+len(tableName) : len(indexName)] + } - var index *Index - var ok bool - if index, ok = indexes[indexName]; !ok { - index = new(Index) - index.Type = indexType - index.Name = indexName - indexes[indexName] = index - } - index.AddColumn(colName) - } - return indexes, nil + var index *Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(Index) + index.Type = indexType + index.Name = indexName + indexes[indexName] = index + } + index.AddColumn(colName) + } + return indexes, nil } diff --git a/pool.go b/pool.go index 78502802..7a7b173e 100644 --- a/pool.go +++ b/pool.go @@ -1,13 +1,13 @@ package xorm import ( - "database/sql" - //"fmt" - "sync" - //"sync/atomic" - "container/list" - "reflect" - "time" + "database/sql" + //"fmt" + "sync" + //"sync/atomic" + "container/list" + "reflect" + "time" ) // Interface IConnecPool is a connection pool interface, all implements should implement @@ -17,14 +17,14 @@ import ( // ReleaseDB for releasing a db connection; // Close for invoking when engine.Close type IConnectPool interface { - Init(engine *Engine) error - RetrieveDB(engine *Engine) (*sql.DB, error) - ReleaseDB(engine *Engine, db *sql.DB) - Close(engine *Engine) error - SetMaxIdleConns(conns int) - MaxIdleConns() int - SetMaxConns(conns int) - MaxConns() int + Init(engine *Engine) error + RetrieveDB(engine *Engine) (*sql.DB, error) + ReleaseDB(engine *Engine, db *sql.DB) + Close(engine *Engine) error + SetMaxIdleConns(conns int) + MaxIdleConns() int + SetMaxConns(conns int) + MaxConns() int } // Struct NoneConnectPool is a implement for IConnectPool. It provides directly invoke driver's @@ -34,35 +34,35 @@ type NoneConnectPool struct { // NewNoneConnectPool new a NoneConnectPool. func NewNoneConnectPool() IConnectPool { - return &NoneConnectPool{} + return &NoneConnectPool{} } // Init do nothing func (p *NoneConnectPool) Init(engine *Engine) error { - return nil + return nil } // RetrieveDB directly open a connection func (p *NoneConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { - db, err = engine.OpenDB() - return + db, err = engine.OpenDB() + return } // ReleaseDB directly close a connection func (p *NoneConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { - db.Close() + db.Close() } // Close do nothing func (p *NoneConnectPool) Close(engine *Engine) error { - return nil + return nil } func (p *NoneConnectPool) SetMaxIdleConns(conns int) { } func (p *NoneConnectPool) MaxIdleConns() int { - return 0 + return 0 } // not implemented @@ -71,133 +71,133 @@ func (p *NoneConnectPool) SetMaxConns(conns int) { // not implemented func (p *NoneConnectPool) MaxConns() int { - return -1 + return -1 } // Struct SysConnectPool is a simple wrapper for using system default connection pool. // About the system connection pool, you can review the code database/sql/sql.go // It's currently default Pool implments. type SysConnectPool struct { - db *sql.DB - maxIdleConns int - maxConns int - curConns int - mutex *sync.Mutex - queue *list.List + db *sql.DB + maxIdleConns int + maxConns int + curConns int + mutex *sync.Mutex + queue *list.List } // NewSysConnectPool new a SysConnectPool. func NewSysConnectPool() IConnectPool { - return &SysConnectPool{} + return &SysConnectPool{} } // Init create a db immediately and keep it util engine closed. func (s *SysConnectPool) Init(engine *Engine) error { - db, err := engine.OpenDB() - if err != nil { - return err - } - s.db = db - s.maxIdleConns = 2 - s.maxConns = -1 - s.curConns = 0 - s.mutex = &sync.Mutex{} - s.queue = list.New() - return nil + db, err := engine.OpenDB() + if err != nil { + return err + } + s.db = db + s.maxIdleConns = 2 + s.maxConns = -1 + s.curConns = 0 + s.mutex = &sync.Mutex{} + s.queue = list.New() + return nil } type node struct { - mutex sync.Mutex - cond *sync.Cond + mutex sync.Mutex + cond *sync.Cond } func newCondNode() *node { - n := &node{} - n.cond = sync.NewCond(&n.mutex) - return n + n := &node{} + n.cond = sync.NewCond(&n.mutex) + return n } // RetrieveDB just return the only db func (s *SysConnectPool) RetrieveDB(engine *Engine) (db *sql.DB, err error) { - /*if s.maxConns > 0 { - fmt.Println("before retrieve") - s.mutex.Lock() - for s.curConns >= s.maxConns { - fmt.Println("before waiting...", s.curConns, s.queue.Len()) - s.mutex.Unlock() - n := NewNode() - n.cond.L.Lock() - s.queue.PushBack(n) - n.cond.Wait() - n.cond.L.Unlock() - s.mutex.Lock() - fmt.Println("after waiting...", s.curConns, s.queue.Len()) - } - s.curConns += 1 - s.mutex.Unlock() - fmt.Println("after retrieve") - }*/ - return s.db, nil + /*if s.maxConns > 0 { + fmt.Println("before retrieve") + s.mutex.Lock() + for s.curConns >= s.maxConns { + fmt.Println("before waiting...", s.curConns, s.queue.Len()) + s.mutex.Unlock() + n := NewNode() + n.cond.L.Lock() + s.queue.PushBack(n) + n.cond.Wait() + n.cond.L.Unlock() + s.mutex.Lock() + fmt.Println("after waiting...", s.curConns, s.queue.Len()) + } + s.curConns += 1 + s.mutex.Unlock() + fmt.Println("after retrieve") + }*/ + return s.db, nil } // ReleaseDB do nothing func (s *SysConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { - /*if s.maxConns > 0 { - s.mutex.Lock() - fmt.Println("before release", s.queue.Len()) - s.curConns -= 1 + /*if s.maxConns > 0 { + s.mutex.Lock() + fmt.Println("before release", s.queue.Len()) + s.curConns -= 1 - if e := s.queue.Front(); e != nil { - n := e.Value.(*node) - //n.cond.L.Lock() - n.cond.Signal() - fmt.Println("signaled...") - s.queue.Remove(e) - //n.cond.L.Unlock() - } - fmt.Println("after released", s.queue.Len()) - s.mutex.Unlock() - }*/ + if e := s.queue.Front(); e != nil { + n := e.Value.(*node) + //n.cond.L.Lock() + n.cond.Signal() + fmt.Println("signaled...") + s.queue.Remove(e) + //n.cond.L.Unlock() + } + fmt.Println("after released", s.queue.Len()) + s.mutex.Unlock() + }*/ } // Close closed the only db func (p *SysConnectPool) Close(engine *Engine) error { - return p.db.Close() + return p.db.Close() } func (p *SysConnectPool) SetMaxIdleConns(conns int) { - p.db.SetMaxIdleConns(conns) - p.maxIdleConns = conns + p.db.SetMaxIdleConns(conns) + p.maxIdleConns = conns } func (p *SysConnectPool) MaxIdleConns() int { - return p.maxIdleConns + return p.maxIdleConns } // not implemented func (p *SysConnectPool) SetMaxConns(conns int) { - p.maxConns = conns - // if support SetMaxOpenConns, go 1.2+, then set - if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() { - reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)}) - } - //p.db.SetMaxOpenConns(conns) + p.maxConns = conns + // if support SetMaxOpenConns, go 1.2+, then set + if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() { + reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)}) + } + //p.db.SetMaxOpenConns(conns) } // not implemented func (p *SysConnectPool) MaxConns() int { - return p.maxConns + return p.maxConns } // NewSimpleConnectPool new a SimpleConnectPool func NewSimpleConnectPool() IConnectPool { - return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10), - usingConnects: map[*sql.DB]time.Time{}, - cur: -1, - maxWaitTimeOut: 14400, - maxIdleConns: 10, - mutex: &sync.Mutex{}, - } + return &SimpleConnectPool{releasedConnects: make([]*sql.DB, 10), + usingConnects: map[*sql.DB]time.Time{}, + cur: -1, + maxWaitTimeOut: 14400, + maxIdleConns: 10, + mutex: &sync.Mutex{}, + } } // Struct SimpleConnectPool is a simple implementation for IConnectPool. @@ -205,75 +205,75 @@ func NewSimpleConnectPool() IConnectPool { // Opening or Closing a database connection must be enter a lock. // This implements will be improved in furture. type SimpleConnectPool struct { - releasedConnects []*sql.DB - cur int - usingConnects map[*sql.DB]time.Time - maxWaitTimeOut int - mutex *sync.Mutex - maxIdleConns int + releasedConnects []*sql.DB + cur int + usingConnects map[*sql.DB]time.Time + maxWaitTimeOut int + mutex *sync.Mutex + maxIdleConns int } func (s *SimpleConnectPool) Init(engine *Engine) error { - return nil + return nil } // RetrieveDB get a connection from connection pool func (p *SimpleConnectPool) RetrieveDB(engine *Engine) (*sql.DB, error) { - p.mutex.Lock() - defer p.mutex.Unlock() - var db *sql.DB = nil - var err error = nil - //fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) - if p.cur < 0 { - db, err = engine.OpenDB() - if err != nil { - return nil, err - } - p.usingConnects[db] = time.Now() - } else { - db = p.releasedConnects[p.cur] - p.usingConnects[db] = time.Now() - p.releasedConnects[p.cur] = nil - p.cur = p.cur - 1 - } + p.mutex.Lock() + defer p.mutex.Unlock() + var db *sql.DB = nil + var err error = nil + //fmt.Printf("%x, rbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) + if p.cur < 0 { + db, err = engine.OpenDB() + if err != nil { + return nil, err + } + p.usingConnects[db] = time.Now() + } else { + db = p.releasedConnects[p.cur] + p.usingConnects[db] = time.Now() + p.releasedConnects[p.cur] = nil + p.cur = p.cur - 1 + } - //fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) - return db, nil + //fmt.Printf("%x, rend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) + return db, nil } // ReleaseDB release a db from connection pool func (p *SimpleConnectPool) ReleaseDB(engine *Engine, db *sql.DB) { - p.mutex.Lock() - defer p.mutex.Unlock() - //fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) - if p.cur >= p.maxIdleConns-1 { - db.Close() - } else { - p.cur = p.cur + 1 - p.releasedConnects[p.cur] = db - } - delete(p.usingConnects, db) - //fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) + p.mutex.Lock() + defer p.mutex.Unlock() + //fmt.Printf("%x, lbegin - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) + if p.cur >= p.maxIdleConns-1 { + db.Close() + } else { + p.cur = p.cur + 1 + p.releasedConnects[p.cur] = db + } + delete(p.usingConnects, db) + //fmt.Printf("%x, lend - released:%v, using:%v\n", &p, p.cur+1, len(p.usingConnects)) } // Close release all db func (p *SimpleConnectPool) Close(engine *Engine) error { - p.mutex.Lock() - defer p.mutex.Unlock() - for len(p.releasedConnects) > 0 { - p.releasedConnects[0].Close() - p.releasedConnects = p.releasedConnects[1:] - } + p.mutex.Lock() + defer p.mutex.Unlock() + for len(p.releasedConnects) > 0 { + p.releasedConnects[0].Close() + p.releasedConnects = p.releasedConnects[1:] + } - return nil + return nil } func (p *SimpleConnectPool) SetMaxIdleConns(conns int) { - p.maxIdleConns = conns + p.maxIdleConns = conns } func (p *SimpleConnectPool) MaxIdleConns() int { - return p.maxIdleConns + return p.maxIdleConns } // not implemented @@ -282,5 +282,5 @@ func (p *SimpleConnectPool) SetMaxConns(conns int) { // not implemented func (p *SimpleConnectPool) MaxConns() int { - return -1 + return -1 } diff --git a/postgres.go b/postgres.go index 7b716c06..c316f9b5 100644 --- a/postgres.go +++ b/postgres.go @@ -1,300 +1,305 @@ package xorm import ( - "database/sql" - "errors" - "fmt" - "strconv" - "strings" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" ) type postgres struct { - base - dbname string + base } type values map[string]string func (vs values) Set(k, v string) { - vs[k] = v + vs[k] = v } func (vs values) Get(k string) (v string) { - return vs[k] + return vs[k] } func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) + panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } func parseOpts(name string, o values) { - if len(name) == 0 { - return - } + if len(name) == 0 { + return + } - name = strings.TrimSpace(name) + name = strings.TrimSpace(name) - ps := strings.Split(name, " ") - for _, p := range ps { - kv := strings.Split(p, "=") - if len(kv) < 2 { - errorf("invalid option: %q", p) - } - o.Set(kv[0], kv[1]) - } + ps := strings.Split(name, " ") + for _, p := range ps { + kv := strings.Split(p, "=") + if len(kv) < 2 { + errorf("invalid option: %q", p) + } + o.Set(kv[0], kv[1]) + } +} + +type postgresParser struct { +} + +func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) { + db := &uri{dbType: POSTGRES} + o := make(values) + parseOpts(dataSourceName, o) + + db.dbName = o.Get("dbname") + if db.dbName == "" { + return nil, errors.New("dbname is empty") + } + return db, nil } func (db *postgres) Init(drivername, uri string) error { - db.base.init(drivername, uri) - - o := make(values) - parseOpts(uri, o) - - db.dbname = o.Get("dbname") - if db.dbname == "" { - return errors.New("dbname is empty") - } - return nil + return db.base.init(&postgresParser{}, drivername, uri) } func (db *postgres) SqlType(c *Column) string { - var res string - switch t := c.SQLType.Name; t { - case TinyInt: - res = SmallInt - case MediumInt, Int, Integer: - return Integer - case Serial, BigSerial: - c.IsAutoIncrement = true - c.Nullable = false - res = t - case Binary, VarBinary: - return Bytea - case DateTime: - res = TimeStamp - case TimeStampz: - return "timestamp with time zone" - case Float: - res = Real - case TinyText, MediumText, LongText: - res = Text - case Blob, TinyBlob, MediumBlob, LongBlob: - return Bytea - case Double: - return "DOUBLE PRECISION" - default: - if c.IsAutoIncrement { - return Serial - } - res = t - } + var res string + switch t := c.SQLType.Name; t { + case TinyInt: + res = SmallInt + case MediumInt, Int, Integer: + return Integer + case Serial, BigSerial: + c.IsAutoIncrement = true + c.Nullable = false + res = t + case Binary, VarBinary: + return Bytea + case DateTime: + res = TimeStamp + case TimeStampz: + return "timestamp with time zone" + case Float: + res = Real + case TinyText, MediumText, LongText: + res = Text + case Blob, TinyBlob, MediumBlob, LongBlob: + return Bytea + case Double: + return "DOUBLE PRECISION" + default: + if c.IsAutoIncrement { + return Serial + } + res = t + } - var hasLen1 bool = (c.Length > 0) - var hasLen2 bool = (c.Length2 > 0) - if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" - } else if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" - } - return res + var hasLen1 bool = (c.Length > 0) + var hasLen2 bool = (c.Length2 > 0) + if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } else if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } + return res } func (db *postgres) SupportInsertMany() bool { - return true + return true } func (db *postgres) QuoteStr() string { - return "\"" + return "\"" } func (db *postgres) AutoIncrStr() string { - return "" + return "" } func (db *postgres) SupportEngine() bool { - return false + return false } func (db *postgres) SupportCharset() bool { - return false + return false } func (db *postgres) IndexOnTable() bool { - return false + return false } func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{tableName, idxName} - return `SELECT indexname FROM pg_indexes ` + - `WHERE tablename = ? AND indexname = ?`, args + args := []interface{}{tableName, idxName} + return `SELECT indexname FROM pg_indexes ` + + `WHERE tablename = ? AND indexname = ?`, args } func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + args := []interface{}{tableName} + return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args } func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args + args := []interface{}{tableName, colName} + return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + + " AND column_name = ?", args } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{tableName} - s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + - ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + args := []interface{}{tableName} + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + + ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, record := range res { - col := new(Column) - col.Indexes = make(map[string]bool) - for name, content := range record { - switch name { - case "column_name": - col.Name = strings.Trim(string(content), `" `) - case "column_default": - if strings.HasPrefix(string(content), "nextval") { - col.IsPrimaryKey = true - } else { - col.Default = string(content) - } - case "is_nullable": - if string(content) == "YES" { - col.Nullable = true - } else { - col.Nullable = false - } - case "data_type": - ct := string(content) - switch ct { - case "character varying", "character": - col.SQLType = SQLType{Varchar, 0, 0} - case "timestamp without time zone": - col.SQLType = SQLType{DateTime, 0, 0} - case "timestamp with time zone": - col.SQLType = SQLType{TimeStampz, 0, 0} - case "double precision": - col.SQLType = SQLType{Double, 0, 0} - case "boolean": - col.SQLType = SQLType{Bool, 0, 0} - case "time without time zone": - col.SQLType = SQLType{Time, 0, 0} - default: - col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} - } - if _, ok := sqlTypes[col.SQLType.Name]; !ok { - return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) - } - case "character_maximum_length": - i, err := strconv.Atoi(string(content)) - if err != nil { - return nil, nil, errors.New("retrieve length error") - } - col.Length = i - case "numeric_precision": - case "numeric_precision_radix": - } - } - if col.SQLType.IsText() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, record := range res { + col := new(Column) + col.Indexes = make(map[string]bool) + for name, content := range record { + switch name { + case "column_name": + col.Name = strings.Trim(string(content), `" `) + case "column_default": + if strings.HasPrefix(string(content), "nextval") { + col.IsPrimaryKey = true + } else { + col.Default = string(content) + } + case "is_nullable": + if string(content) == "YES" { + col.Nullable = true + } else { + col.Nullable = false + } + case "data_type": + ct := string(content) + switch ct { + case "character varying", "character": + col.SQLType = SQLType{Varchar, 0, 0} + case "timestamp without time zone": + col.SQLType = SQLType{DateTime, 0, 0} + case "timestamp with time zone": + col.SQLType = SQLType{TimeStampz, 0, 0} + case "double precision": + col.SQLType = SQLType{Double, 0, 0} + case "boolean": + col.SQLType = SQLType{Bool, 0, 0} + case "time without time zone": + col.SQLType = SQLType{Time, 0, 0} + default: + col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} + } + if _, ok := sqlTypes[col.SQLType.Name]; !ok { + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) + } + case "character_maximum_length": + i, err := strconv.Atoi(string(content)) + if err != nil { + return nil, nil, errors.New("retrieve length error") + } + col.Length = i + case "numeric_precision": + case "numeric_precision_radix": + } + } + if col.SQLType.IsText() { + if col.Default != "" { + col.Default = "'" + col.Default + "'" + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } - return colSeq, cols, nil + return colSeq, cols, nil } func (db *postgres) GetTables() ([]*Table, error) { - args := []interface{}{} - s := "SELECT tablename FROM pg_tables where schemaname = 'public'" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{} + s := "SELECT tablename FROM pg_tables where schemaname = 'public'" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "tablename": - table.Name = string(content) - } - } - tables = append(tables, table) - } - return tables, nil + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "tablename": + table.Name = string(content) + } + } + tables = append(tables, table) + } + return tables, nil } func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{tableName} - s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" + args := []interface{}{tableName} + s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - indexes := make(map[string]*Index, 0) - for _, record := range res { - var indexType int - var indexName string - var colNames []string + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName string + var colNames []string - for name, content := range record { - switch name { - case "indexname": - indexName = strings.Trim(string(content), `" `) - case "indexdef": - c := string(content) - if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { - indexType = UniqueType - } else { - indexType = IndexType - } - cs := strings.Split(c, "(") - colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") - } - } - if strings.HasSuffix(indexName, "_pkey") { - continue - } - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - newIdxName := indexName[5+len(tableName) : len(indexName)] - if newIdxName != "" { - indexName = newIdxName - } - } + for name, content := range record { + switch name { + case "indexname": + indexName = strings.Trim(string(content), `" `) + case "indexdef": + c := string(content) + if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { + indexType = UniqueType + } else { + indexType = IndexType + } + cs := strings.Split(c, "(") + colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") + } + } + if strings.HasSuffix(indexName, "_pkey") { + continue + } + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + newIdxName := indexName[5+len(tableName) : len(indexName)] + if newIdxName != "" { + indexName = newIdxName + } + } - index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} - for _, colName := range colNames { - index.Cols = append(index.Cols, strings.Trim(colName, `" `)) - } - indexes[index.Name] = index - } - return indexes, nil + index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + for _, colName := range colNames { + index.Cols = append(index.Cols, strings.Trim(colName, `" `)) + } + indexes[index.Name] = index + } + return indexes, nil } diff --git a/processors.go b/processors.go index d1ea25ef..770515e6 100644 --- a/processors.go +++ b/processors.go @@ -2,17 +2,17 @@ package xorm // Executed before an object is initially persisted to the database type BeforeInsertProcessor interface { - BeforeInsert() + BeforeInsert() } // Executed before an object is updated type BeforeUpdateProcessor interface { - BeforeUpdate() + BeforeUpdate() } // Executed before an object is deleted type BeforeDeleteProcessor interface { - BeforeDelete() + BeforeDelete() } // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations @@ -24,16 +24,15 @@ type BeforeDeleteProcessor interface { // Executed after an object is persisted to the database type AfterInsertProcessor interface { - AfterInsert() + AfterInsert() } // Executed after an object has been updated type AfterUpdateProcessor interface { - AfterUpdate() + AfterUpdate() } // Executed after an object has been deleted type AfterDeleteProcessor interface { - AfterDelete() + AfterDelete() } - diff --git a/session.go b/session.go index 01853736..c8de127d 100644 --- a/session.go +++ b/session.go @@ -1,132 +1,132 @@ package xorm import ( - "database/sql" - "encoding/json" - "errors" - "fmt" - "reflect" - "strconv" - "strings" - "time" + "database/sql" + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" ) // Struct Session keep a pointer to sql.DB and provides all execution of all // kind of database operations. type Session struct { - Db *sql.DB - Engine *Engine - Tx *sql.Tx - Statement Statement - IsAutoCommit bool - IsCommitedOrRollbacked bool - TransType string - IsAutoClose bool + Db *sql.DB + Engine *Engine + Tx *sql.Tx + Statement Statement + IsAutoCommit bool + IsCommitedOrRollbacked bool + TransType string + IsAutoClose bool - // !nashtsai! storing these beans due to yet committed tx - // afterInsertBeans []interface{} - // afterUpdateBeans []interface{} - // afterDeleteBeans []interface{} - afterInsertBeans map[interface{}]*[]func(interface{}) - afterUpdateBeans map[interface{}]*[]func(interface{}) - afterDeleteBeans map[interface{}]*[]func(interface{}) - // -- + // !nashtsai! storing these beans due to yet committed tx + // afterInsertBeans []interface{} + // afterUpdateBeans []interface{} + // afterDeleteBeans []interface{} + afterInsertBeans map[interface{}]*[]func(interface{}) + afterUpdateBeans map[interface{}]*[]func(interface{}) + afterDeleteBeans map[interface{}]*[]func(interface{}) + // -- - beforeClosures []func(interface{}) - afterClosures []func(interface{}) + beforeClosures []func(interface{}) + afterClosures []func(interface{}) } // Method Init reset the session as the init status. func (session *Session) Init() { - session.Statement = Statement{Engine: session.Engine} - session.Statement.Init() - session.IsAutoCommit = true - session.IsCommitedOrRollbacked = false - session.IsAutoClose = false + session.Statement = Statement{Engine: session.Engine} + session.Statement.Init() + session.IsAutoCommit = true + session.IsCommitedOrRollbacked = false + session.IsAutoClose = false - // !nashtsai! is lazy init better? - session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterUpdateBeans = make(map[interface{}]*[]func(interface{}), 0) - session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) - session.beforeClosures = make([]func(interface{}), 0) - session.afterClosures = make([]func(interface{}), 0) + // !nashtsai! is lazy init better? + session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) + session.afterUpdateBeans = make(map[interface{}]*[]func(interface{}), 0) + session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) + session.beforeClosures = make([]func(interface{}), 0) + session.afterClosures = make([]func(interface{}), 0) } // Method Close release the connection from pool func (session *Session) Close() { - defer func() { - if session.Db != nil { - session.Engine.Pool.ReleaseDB(session.Engine, session.Db) - session.Db = nil - session.Tx = nil - session.Init() - } - }() + defer func() { + if session.Db != nil { + session.Engine.Pool.ReleaseDB(session.Engine, session.Db) + session.Db = nil + session.Tx = nil + session.Init() + } + }() } // Method Sql provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use Sql. func (session *Session) Sql(querystring string, args ...interface{}) *Session { - session.Statement.Sql(querystring, args...) - return session + session.Statement.Sql(querystring, args...) + return session } // Method Where provides custom query condition. func (session *Session) Where(querystring string, args ...interface{}) *Session { - session.Statement.Where(querystring, args...) - return session + session.Statement.Where(querystring, args...) + return session } // Method Where provides custom query condition. func (session *Session) And(querystring string, args ...interface{}) *Session { - session.Statement.And(querystring, args...) - return session + session.Statement.And(querystring, args...) + return session } // Method Where provides custom query condition. func (session *Session) Or(querystring string, args ...interface{}) *Session { - session.Statement.Or(querystring, args...) - return session + session.Statement.Or(querystring, args...) + return session } // Method Id provides converting id as a query condition func (session *Session) Id(id interface{}) *Session { - session.Statement.Id(id) - return session + session.Statement.Id(id) + return session } // Apply before Processor, affected bean is passed to closure arg func (session *Session) Before(closures func(interface{})) *Session { - if closures != nil { - session.beforeClosures = append(session.beforeClosures, closures) - } - return session + if closures != nil { + session.beforeClosures = append(session.beforeClosures, closures) + } + return session } // Apply after Processor, affected bean is passed to closure arg func (session *Session) After(closures func(interface{})) *Session { - if closures != nil { - session.afterClosures = append(session.afterClosures, closures) - } - return session + if closures != nil { + session.afterClosures = append(session.afterClosures, closures) + } + return session } // Method Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { - session.Statement.Table(tableNameOrBean) - return session + session.Statement.Table(tableNameOrBean) + return session } // Method In provides a query string like "id in (1, 2, 3)" func (session *Session) In(column string, args ...interface{}) *Session { - session.Statement.In(column, args...) - return session + session.Statement.In(column, args...) + return session } // Method Cols provides some columns to special func (session *Session) Cols(columns ...string) *Session { - session.Statement.Cols(columns...) - return session + session.Statement.Cols(columns...) + return session } // Xorm automatically retrieve condition according struct, but @@ -135,674 +135,674 @@ func (session *Session) Cols(columns ...string) *Session { // If no paramters, it will use all the bool field of struct, or // it will use paramters's columns func (session *Session) UseBool(columns ...string) *Session { - session.Statement.UseBool(columns...) - return session + session.Statement.UseBool(columns...) + return session } // use for distinct columns. Caution: when you are using cache, // distinct will not be cached because cache system need id, // but distinct will not provide id func (session *Session) Distinct(columns ...string) *Session { - session.Statement.Distinct(columns...) - return session + session.Statement.Distinct(columns...) + return session } // Only not use the paramters as select or update columns func (session *Session) Omit(columns ...string) *Session { - session.Statement.Omit(columns...) - return session + session.Statement.Omit(columns...) + return session } // Method NoAutoTime means do not automatically give created field and updated field // the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { - session.Statement.UseAutoTime = false - return session + session.Statement.UseAutoTime = false + return session } // Method Limit provide limit and offset query condition func (session *Session) Limit(limit int, start ...int) *Session { - session.Statement.Limit(limit, start...) - return session + session.Statement.Limit(limit, start...) + return session } // Method OrderBy provide order by query condition, the input parameter is the content // after order by on a sql statement. func (session *Session) OrderBy(order string) *Session { - session.Statement.OrderBy(order) - return session + session.Statement.OrderBy(order) + return session } // Method Desc provide desc order by query condition, the input parameters are columns. func (session *Session) Desc(colNames ...string) *Session { - if session.Statement.OrderStr != "" { - session.Statement.OrderStr += ", " - } - newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" - return session + if session.Statement.OrderStr != "" { + session.Statement.OrderStr += ", " + } + newColNames := col2NewCols(colNames...) + sql := strings.Join(newColNames, session.Engine.Quote(" DESC, ")) + session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" + return session } // Method Asc provide asc order by query condition, the input parameters are columns. func (session *Session) Asc(colNames ...string) *Session { - if session.Statement.OrderStr != "" { - session.Statement.OrderStr += ", " - } - newColNames := col2NewCols(colNames...) - sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) - session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" - return session + if session.Statement.OrderStr != "" { + session.Statement.OrderStr += ", " + } + newColNames := col2NewCols(colNames...) + sql := strings.Join(newColNames, session.Engine.Quote(" ASC, ")) + session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" + return session } // Method StoreEngine is only avialble mysql dialect currently func (session *Session) StoreEngine(storeEngine string) *Session { - session.Statement.StoreEngine = storeEngine - return session + session.Statement.StoreEngine = storeEngine + return session } // Method StoreEngine is only avialble charset dialect currently func (session *Session) Charset(charset string) *Session { - session.Statement.Charset = charset - return session + session.Statement.Charset = charset + return session } // Method Cascade indicates if loading sub Struct func (session *Session) Cascade(trueOrFalse ...bool) *Session { - if len(trueOrFalse) >= 1 { - session.Statement.UseCascade = trueOrFalse[0] - } - return session + if len(trueOrFalse) >= 1 { + session.Statement.UseCascade = trueOrFalse[0] + } + return session } // Method NoCache ask this session do not retrieve data from cache system and // get data from database directly. func (session *Session) NoCache() *Session { - session.Statement.UseCache = false - return session + session.Statement.UseCache = false + return session } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (session *Session) Join(join_operator, tablename, condition string) *Session { - session.Statement.Join(join_operator, tablename, condition) - return session + session.Statement.Join(join_operator, tablename, condition) + return session } // Generate Group By statement func (session *Session) GroupBy(keys string) *Session { - session.Statement.GroupBy(keys) - return session + session.Statement.GroupBy(keys) + return session } // Generate Having statement func (session *Session) Having(conditions string) *Session { - session.Statement.Having(conditions) - return session + session.Statement.Having(conditions) + return session } func (session *Session) newDb() error { - if session.Db == nil { - db, err := session.Engine.Pool.RetrieveDB(session.Engine) - if err != nil { - return err - } - session.Db = db - } - return nil + if session.Db == nil { + db, err := session.Engine.Pool.RetrieveDB(session.Engine) + if err != nil { + return err + } + session.Db = db + } + return nil } // Begin a transaction func (session *Session) Begin() error { - err := session.newDb() - if err != nil { - return err - } - if session.IsAutoCommit { - tx, err := session.Db.Begin() - if err != nil { - return err - } - session.IsAutoCommit = false - session.IsCommitedOrRollbacked = false - session.Tx = tx + err := session.newDb() + if err != nil { + return err + } + if session.IsAutoCommit { + tx, err := session.Db.Begin() + if err != nil { + return err + } + session.IsAutoCommit = false + session.IsCommitedOrRollbacked = false + session.Tx = tx - session.Engine.LogSQL("BEGIN TRANSACTION") - } - return nil + session.Engine.LogSQL("BEGIN TRANSACTION") + } + return nil } // When using transaction, you can rollback if any error func (session *Session) Rollback() error { - if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.LogSQL("ROLL BACK") - session.IsCommitedOrRollbacked = true - return session.Tx.Rollback() - } - return nil + if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { + session.Engine.LogSQL("ROLL BACK") + session.IsCommitedOrRollbacked = true + return session.Tx.Rollback() + } + return nil } // When using transaction, Commit will commit all operations. func (session *Session) Commit() error { - if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.LogSQL("COMMIT") - session.IsCommitedOrRollbacked = true - var err error - if err = session.Tx.Commit(); err == nil { - // handle processors after tx committed + if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { + session.Engine.LogSQL("COMMIT") + session.IsCommitedOrRollbacked = true + var err error + if err = session.Tx.Commit(); err == nil { + // handle processors after tx committed - closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { + closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { - if closuresPtr != nil { - for _, closure := range *closuresPtr { - closure(bean) - } - } - } + if closuresPtr != nil { + for _, closure := range *closuresPtr { + closure(bean) + } + } + } - for bean, closuresPtr := range session.afterInsertBeans { - closureCallFunc(closuresPtr, bean) + for bean, closuresPtr := range session.afterInsertBeans { + closureCallFunc(closuresPtr, bean) - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - processor.AfterInsert() - } - } - for bean, closuresPtr := range session.afterUpdateBeans { - closureCallFunc(closuresPtr, bean) + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + processor.AfterInsert() + } + } + for bean, closuresPtr := range session.afterUpdateBeans { + closureCallFunc(closuresPtr, bean) - if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - processor.AfterUpdate() - } - } - for bean, closuresPtr := range session.afterDeleteBeans { - closureCallFunc(closuresPtr, bean) + if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { + processor.AfterUpdate() + } + } + for bean, closuresPtr := range session.afterDeleteBeans { + closureCallFunc(closuresPtr, bean) - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } - cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { - if len(*slices) > 0 { - *slices = make(map[interface{}]*[]func(interface{}), 0) - } - } - cleanUpFunc(&session.afterInsertBeans) - cleanUpFunc(&session.afterUpdateBeans) - cleanUpFunc(&session.afterDeleteBeans) - } - return err - } - return nil + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() + } + } + cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { + if len(*slices) > 0 { + *slices = make(map[interface{}]*[]func(interface{}), 0) + } + } + cleanUpFunc(&session.afterInsertBeans) + cleanUpFunc(&session.afterUpdateBeans) + cleanUpFunc(&session.afterDeleteBeans) + } + return err + } + return nil } func cleanupProcessorsClosures(slices *[]func(interface{})) { - if len(*slices) > 0 { - *slices = make([]func(interface{}), 0) - } + if len(*slices) > 0 { + *slices = make([]func(interface{}), 0) + } } func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { - dataStruct := reflect.Indirect(reflect.ValueOf(obj)) - if dataStruct.Kind() != reflect.Struct { - return errors.New("Expected a pointer to a struct") - } + dataStruct := reflect.Indirect(reflect.ValueOf(obj)) + if dataStruct.Kind() != reflect.Struct { + return errors.New("Expected a pointer to a struct") + } - table := session.Engine.autoMapType(rType(obj)) + table := session.Engine.autoMapType(rType(obj)) - for key, data := range objMap { - key = strings.ToLower(key) - if _, ok := table.Columns[key]; !ok { - session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) - continue - } - col := table.Columns[key] - fieldName := col.FieldName - fieldPath := strings.Split(fieldName, ".") - var fieldValue reflect.Value - if len(fieldPath) > 2 { - session.Engine.LogError("Unsupported mutliderive", fieldName) - continue - } else if len(fieldPath) == 2 { - parentField := dataStruct.FieldByName(fieldPath[0]) - if parentField.IsValid() { - fieldValue = parentField.FieldByName(fieldPath[1]) - } - } else { - fieldValue = dataStruct.FieldByName(fieldName) - } - if !fieldValue.IsValid() || !fieldValue.CanSet() { - session.Engine.LogWarn("table %v's column %v is not valid or cannot set", - table.Name, key) - continue - } + for key, data := range objMap { + key = strings.ToLower(key) + if _, ok := table.Columns[key]; !ok { + session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) + continue + } + col := table.Columns[key] + fieldName := col.FieldName + fieldPath := strings.Split(fieldName, ".") + var fieldValue reflect.Value + if len(fieldPath) > 2 { + session.Engine.LogError("Unsupported mutliderive", fieldName) + continue + } else if len(fieldPath) == 2 { + parentField := dataStruct.FieldByName(fieldPath[0]) + if parentField.IsValid() { + fieldValue = parentField.FieldByName(fieldPath[1]) + } + } else { + fieldValue = dataStruct.FieldByName(fieldName) + } + if !fieldValue.IsValid() || !fieldValue.CanSet() { + session.Engine.LogWarn("table %v's column %v is not valid or cannot set", + table.Name, key) + continue + } - err := session.bytes2Value(col, &fieldValue, data) - if err != nil { - return err - } - } + err := session.bytes2Value(col, &fieldValue, data) + if err != nil { + return err + } + } - return nil + return nil } //Execute sql func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(sql) - if err != nil { - return nil, err - } - defer rs.Close() + rs, err := session.Db.Prepare(sql) + if err != nil { + return nil, err + } + defer rs.Close() - res, err := rs.Exec(args...) - if err != nil { - return nil, err - } - return res, nil + res, err := rs.Exec(args...) + if err != nil { + return nil, err + } + return res, nil } func (session *Session) exec(sql string, args ...interface{}) (sql.Result, error) { - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } - session.Engine.LogSQL(sql) - session.Engine.LogSQL(args) + session.Engine.LogSQL(sql) + session.Engine.LogSQL(args) - if session.IsAutoCommit { - return session.innerExec(sql, args...) - } - return session.Tx.Exec(sql, args...) + if session.IsAutoCommit { + return session.innerExec(sql, args...) + } + return session.Tx.Exec(sql, args...) } // Exec raw sql func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { - err := session.newDb() - if err != nil { - return nil, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return nil, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - return session.exec(sql, args...) + return session.exec(sql, args...) } // this function create a table according a bean func (session *Session) CreateTable(bean interface{}) error { - session.Statement.RefTable = session.Engine.autoMap(bean) + session.Statement.RefTable = session.Engine.autoMap(bean) - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - return session.createOneTable() + return session.createOneTable() } // create indexes func (session *Session) CreateIndexes(bean interface{}) error { - session.Statement.RefTable = session.Engine.autoMap(bean) + session.Statement.RefTable = session.Engine.autoMap(bean) - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - sqls := session.Statement.genIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) - if err != nil { - return err - } - } - return nil + sqls := session.Statement.genIndexSQL() + for _, sql := range sqls { + _, err = session.exec(sql) + if err != nil { + return err + } + } + return nil } // create uniques func (session *Session) CreateUniques(bean interface{}) error { - session.Statement.RefTable = session.Engine.autoMap(bean) + session.Statement.RefTable = session.Engine.autoMap(bean) - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - sqls := session.Statement.genUniqueSQL() - for _, sql := range sqls { - _, err = session.exec(sql) - if err != nil { - return err - } - } - return nil + sqls := session.Statement.genUniqueSQL() + for _, sql := range sqls { + _, err = session.exec(sql) + if err != nil { + return err + } + } + return nil } func (session *Session) createOneTable() error { - sql := session.Statement.genCreateTableSQL() - session.Engine.LogDebug("create table sql: [", sql, "]") - _, err := session.exec(sql) - return err + sql := session.Statement.genCreateTableSQL() + session.Engine.LogDebug("create table sql: [", sql, "]") + _, err := session.exec(sql) + return err } // to be deleted func (session *Session) createAll() error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - for _, table := range session.Engine.Tables { - session.Statement.RefTable = table - err := session.createOneTable() - if err != nil { - return err - } - } - return nil + for _, table := range session.Engine.Tables { + session.Statement.RefTable = table + err := session.createOneTable() + if err != nil { + return err + } + } + return nil } // drop indexes func (session *Session) DropIndexes(bean interface{}) error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - sqls := session.Statement.genDelIndexSQL() - for _, sql := range sqls { - _, err = session.exec(sql) - if err != nil { - return err - } - } - return nil + sqls := session.Statement.genDelIndexSQL() + for _, sql := range sqls { + _, err = session.exec(sql) + if err != nil { + return err + } + } + return nil } // DropTable drop a table and all indexes of the table func (session *Session) DropTable(bean interface{}) error { - err := session.newDb() - if err != nil { - return err - } + err := session.newDb() + if err != nil { + return err + } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - t := reflect.Indirect(reflect.ValueOf(bean)).Type() - defer session.Statement.Init() - if t.Kind() == reflect.String { - session.Statement.AltTableName = bean.(string) - } else if t.Kind() == reflect.Struct { - session.Statement.RefTable = session.Engine.autoMap(bean) - } else { - return errors.New("Unsupported type") - } + t := reflect.Indirect(reflect.ValueOf(bean)).Type() + defer session.Statement.Init() + if t.Kind() == reflect.String { + session.Statement.AltTableName = bean.(string) + } else if t.Kind() == reflect.Struct { + session.Statement.RefTable = session.Engine.autoMap(bean) + } else { + return errors.New("Unsupported type") + } - sql := session.Statement.genDropSQL() - _, err = session.exec(sql) - return err + sql := session.Statement.genDropSQL() + _, err = session.exec(sql) + return err } func (statement *Statement) convertIdSql(sql string) string { - if statement.RefTable != nil { - col := statement.RefTable.PKColumn() - if col != nil { - sqls := splitNNoCase(sql, "from", 2) - if len(sqls) != 2 { - return "" - } - newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), - statement.Engine.Quote(col.Name), sqls[1]) - return newsql - } - } - return "" + if statement.RefTable != nil { + col := statement.RefTable.PKColumn() + if col != nil { + sqls := splitNNoCase(sql, "from", 2) + if len(sqls) != 2 { + return "" + } + newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), + statement.Engine.Quote(col.Name), sqls[1]) + return newsql + } + } + return "" } func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { - return false, ErrCacheFailed - } - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } - newsql := session.Statement.convertIdSql(sql) - if newsql == "" { - return false, ErrCacheFailed - } + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return false, ErrCacheFailed + } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + newsql := session.Statement.convertIdSql(sql) + if newsql == "" { + return false, ErrCacheFailed + } - cacher := session.Statement.RefTable.Cacher - tableName := session.Statement.TableName() - session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args) - ids, err := getCacheSql(cacher, tableName, newsql, args) - if err != nil { - resultsSlice, err := session.query(newsql, args...) - if err != nil { - return false, err - } - session.Engine.LogDebug("[xorm:cacheGet] query ids:", resultsSlice) - ids = make([]int64, 0) - if len(resultsSlice) > 0 { - data := resultsSlice[0] - var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { - return false, ErrCacheFailed - } else { - id, err = strconv.ParseInt(string(v), 10, 64) - if err != nil { - return false, err - } - } - ids = append(ids, id) - } - session.Engine.LogDebug("[xorm:cacheGet] cache ids:", newsql, ids) - err = putCacheSql(cacher, ids, tableName, newsql, args) - if err != nil { - return false, err - } - } else { - session.Engine.LogDebug("[xorm:cacheGet] cached sql:", newsql) - } + cacher := session.Statement.RefTable.Cacher + tableName := session.Statement.TableName() + session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args) + ids, err := getCacheSql(cacher, tableName, newsql, args) + if err != nil { + resultsSlice, err := session.query(newsql, args...) + if err != nil { + return false, err + } + session.Engine.LogDebug("[xorm:cacheGet] query ids:", resultsSlice) + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + data := resultsSlice[0] + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return false, ErrCacheFailed + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return false, err + } + } + ids = append(ids, id) + } + session.Engine.LogDebug("[xorm:cacheGet] cache ids:", newsql, ids) + err = putCacheSql(cacher, ids, tableName, newsql, args) + if err != nil { + return false, err + } + } else { + session.Engine.LogDebug("[xorm:cacheGet] cached sql:", newsql) + } - if len(ids) > 0 { - structValue := reflect.Indirect(reflect.ValueOf(bean)) - id := ids[0] - session.Engine.LogDebug("[xorm:cacheGet] get bean:", tableName, id) - cacheBean := cacher.GetBean(tableName, id) - if cacheBean == nil { - newSession := session.Engine.NewSession() - defer newSession.Close() - cacheBean = reflect.New(structValue.Type()).Interface() - if session.Statement.AltTableName != "" { - has, err = newSession.Id(id).NoCache().Table(session.Statement.AltTableName).Get(cacheBean) - } else { - has, err = newSession.Id(id).NoCache().Get(cacheBean) - } - if err != nil || !has { - return has, err - } + if len(ids) > 0 { + structValue := reflect.Indirect(reflect.ValueOf(bean)) + id := ids[0] + session.Engine.LogDebug("[xorm:cacheGet] get bean:", tableName, id) + cacheBean := cacher.GetBean(tableName, id) + if cacheBean == nil { + newSession := session.Engine.NewSession() + defer newSession.Close() + cacheBean = reflect.New(structValue.Type()).Interface() + if session.Statement.AltTableName != "" { + has, err = newSession.Id(id).NoCache().Table(session.Statement.AltTableName).Get(cacheBean) + } else { + has, err = newSession.Id(id).NoCache().Get(cacheBean) + } + if err != nil || !has { + return has, err + } - session.Engine.LogDebug("[xorm:cacheGet] cache bean:", tableName, id, cacheBean) - cacher.PutBean(tableName, id, cacheBean) - } else { - session.Engine.LogDebug("[xorm:cacheGet] cached bean:", tableName, id, cacheBean) - has = true - } - structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) + session.Engine.LogDebug("[xorm:cacheGet] cache bean:", tableName, id, cacheBean) + cacher.PutBean(tableName, id, cacheBean) + } else { + session.Engine.LogDebug("[xorm:cacheGet] cached bean:", tableName, id, cacheBean) + has = true + } + structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) - return has, nil - } - return false, nil + return has, nil + } + return false, nil } func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { - if session.Statement.RefTable == nil || - session.Statement.RefTable.PrimaryKey == "" || - indexNoCase(sql, "having") != -1 || - indexNoCase(sql, "group by") != -1 { - return ErrCacheFailed - } + if session.Statement.RefTable == nil || + session.Statement.RefTable.PrimaryKey == "" || + indexNoCase(sql, "having") != -1 || + indexNoCase(sql, "group by") != -1 { + return ErrCacheFailed + } - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } - newsql := session.Statement.convertIdSql(sql) - if newsql == "" { - return ErrCacheFailed - } + newsql := session.Statement.convertIdSql(sql) + if newsql == "" { + return ErrCacheFailed + } - table := session.Statement.RefTable - cacher := table.Cacher - ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) - if err != nil { - //session.Engine.LogError(err) - resultsSlice, err := session.query(newsql, args...) - if err != nil { - return err - } - // 查询数目太大,采用缓存将不是一个很好的方式。 - if len(resultsSlice) > 500 { - session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 500, no cache", len(resultsSlice)) - return ErrCacheFailed - } + table := session.Statement.RefTable + cacher := table.Cacher + ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) + if err != nil { + //session.Engine.LogError(err) + resultsSlice, err := session.query(newsql, args...) + if err != nil { + return err + } + // 查询数目太大,采用缓存将不是一个很好的方式。 + if len(resultsSlice) > 500 { + session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 500, no cache", len(resultsSlice)) + return ErrCacheFailed + } - tableName := session.Statement.TableName() - ids = make([]int64, 0) - if len(resultsSlice) > 0 { - for _, data := range resultsSlice { - //fmt.Println(data) - var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { - return errors.New("no id") - } else { - id, err = strconv.ParseInt(string(v), 10, 64) - if err != nil { - return err - } - } - ids = append(ids, id) - } - } - session.Engine.LogDebug("[xorm:cacheFind] cache ids:", ids, tableName, newsql, args) - err = putCacheSql(cacher, ids, tableName, newsql, args) - if err != nil { - return err - } - } else { - session.Engine.LogDebug("[xorm:cacheFind] cached sql:", newsql, args) - } + tableName := session.Statement.TableName() + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + for _, data := range resultsSlice { + //fmt.Println(data) + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + } + ids = append(ids, id) + } + } + session.Engine.LogDebug("[xorm:cacheFind] cache ids:", ids, tableName, newsql, args) + err = putCacheSql(cacher, ids, tableName, newsql, args) + if err != nil { + return err + } + } else { + session.Engine.LogDebug("[xorm:cacheFind] cached sql:", newsql, args) + } - sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - pkFieldName := session.Statement.RefTable.PKColumn().FieldName + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + pkFieldName := session.Statement.RefTable.PKColumn().FieldName - ididxes := make(map[int64]int) - var ides []interface{} = make([]interface{}, 0) - var temps []interface{} = make([]interface{}, len(ids)) - tableName := session.Statement.TableName() - for idx, id := range ids { - bean := cacher.GetBean(tableName, id) - if bean == nil { - ides = append(ides, id) - ididxes[id] = idx - } else { - session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) + ididxes := make(map[int64]int) + var ides []interface{} = make([]interface{}, 0) + var temps []interface{} = make([]interface{}, len(ids)) + tableName := session.Statement.TableName() + for idx, id := range ids { + bean := cacher.GetBean(tableName, id) + if bean == nil { + ides = append(ides, id) + ididxes[id] = idx + } else { + session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) - sid := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() - if sid != id { - session.Engine.LogError("[xorm:cacheFind] error cache", id, sid, bean) - return ErrCacheFailed - } - temps[idx] = bean - } - } + sid := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() + if sid != id { + session.Engine.LogError("[xorm:cacheFind] error cache", id, sid, bean) + return ErrCacheFailed + } + temps[idx] = bean + } + } - if len(ides) > 0 { - newSession := session.Engine.NewSession() - defer newSession.Close() + if len(ides) > 0 { + newSession := session.Engine.NewSession() + defer newSession.Close() - slices := reflect.New(reflect.SliceOf(t)) - beans := slices.Interface() - //beans := reflect.New(sliceValue.Type()).Interface() - //err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) - err = newSession.In("(id)", ides...).NoCache().Find(beans) - if err != nil { - return err - } + slices := reflect.New(reflect.SliceOf(t)) + beans := slices.Interface() + //beans := reflect.New(sliceValue.Type()).Interface() + //err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) + err = newSession.In("(id)", ides...).NoCache().Find(beans) + if err != nil { + return err + } - vs := reflect.Indirect(reflect.ValueOf(beans)) - for i := 0; i < vs.Len(); i++ { - rv := vs.Index(i) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - bean := rv.Interface() - id := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() - //bean := vs.Index(i).Addr().Interface() - temps[ididxes[id]] = bean - //temps[idxes[i]] = bean - session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, id, bean) - cacher.PutBean(tableName, id, bean) - } - } + vs := reflect.Indirect(reflect.ValueOf(beans)) + for i := 0; i < vs.Len(); i++ { + rv := vs.Index(i) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() + } + bean := rv.Interface() + id := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() + //bean := vs.Index(i).Addr().Interface() + temps[ididxes[id]] = bean + //temps[idxes[i]] = bean + session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, id, bean) + cacher.PutBean(tableName, id, bean) + } + } - for j := 0; j < len(temps); j++ { - bean := temps[j] - if bean == nil { - session.Engine.LogError("[xorm:cacheFind] cache error:", tableName, ides[j], bean) - return errors.New("cache error") - } - if sliceValue.Kind() == reflect.Slice { - if t.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean))) - } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) - } - } else if sliceValue.Kind() == reflect.Map { - var key int64 - if table.PrimaryKey != "" { - key = ids[j] - } else { - key = int64(j) - } - if t.Kind() == reflect.Ptr { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean)) - } else { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) - } - } - /*} else { - session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j]) - cacher.DelBean(tableName, ids[j]) + for j := 0; j < len(temps); j++ { + bean := temps[j] + if bean == nil { + session.Engine.LogError("[xorm:cacheFind] cache error:", tableName, ides[j], bean) + return errors.New("cache error") + } + if sliceValue.Kind() == reflect.Slice { + if t.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) + } + } else if sliceValue.Kind() == reflect.Map { + var key int64 + if table.PrimaryKey != "" { + key = ids[j] + } else { + key = int64(j) + } + if t.Kind() == reflect.Ptr { + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean)) + } else { + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) + } + } + /*} else { + session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j]) + cacher.DelBean(tableName, ids[j]) - session.Engine.LogDebug("[xorm:cacheFind] cache clear:", tableName) - cacher.ClearIds(tableName) - }*/ - } + session.Engine.LogDebug("[xorm:cacheFind] cache clear:", tableName) + cacher.ClearIds(tableName) + }*/ + } - return nil + return nil } // IterFunc only use by Iterate @@ -812,1630 +812,1629 @@ type IterFunc func(idx int, bean interface{}) error // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (session *Session) Iterate(bean interface{}, fun IterFunc) error { - err := session.newDb() - if err != nil { - return err - } + err := session.newDb() + if err != nil { + return err + } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - var sql string - var args []interface{} - session.Statement.RefTable = session.Engine.autoMap(bean) - if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) - } else { - sql = session.Statement.RawSQL - args = session.Statement.RawParams - } + var sql string + var args []interface{} + session.Statement.RefTable = session.Engine.autoMap(bean) + if session.Statement.RawSQL == "" { + sql, args = session.Statement.genGetSql(bean) + } else { + sql = session.Statement.RawSQL + args = session.Statement.RawParams + } - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } - session.Engine.LogSQL(sql) - session.Engine.LogSQL(args) + session.Engine.LogSQL(sql) + session.Engine.LogSQL(args) - s, err := session.Db.Prepare(sql) - if err != nil { - return err - } - defer s.Close() - rows, err := s.Query(args...) - if err != nil { - return err - } - defer rows.Close() + s, err := session.Db.Prepare(sql) + if err != nil { + return err + } + defer s.Close() + rows, err := s.Query(args...) + if err != nil { + return err + } + defer rows.Close() - fields, err := rows.Columns() - if err != nil { - return err - } - t := reflect.Indirect(reflect.ValueOf(bean)).Type() - b := reflect.New(t).Interface() - i := 0 - for rows.Next() { - result, err := row2map(rows, fields) - if err == nil { - err = session.scanMapIntoStruct(b, result) - } - if err == nil { - err = fun(i, b) - i = i + 1 - } - if err != nil { - return err - } - } + fields, err := rows.Columns() + if err != nil { + return err + } + t := reflect.Indirect(reflect.ValueOf(bean)).Type() + b := reflect.New(t).Interface() + i := 0 + for rows.Next() { + result, err := row2map(rows, fields) + if err == nil { + err = session.scanMapIntoStruct(b, result) + } + if err == nil { + err = fun(i, b) + i = i + 1 + } + if err != nil { + return err + } + } - return nil + return nil } // get retrieve one record from database, bean's non-empty fields // will be as conditions func (session *Session) Get(bean interface{}) (bool, error) { - err := session.newDb() - if err != nil { - return false, err - } + err := session.newDb() + if err != nil { + return false, err + } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - session.Statement.Limit(1) - var sql string - var args []interface{} - session.Statement.RefTable = session.Engine.autoMap(bean) + session.Statement.Limit(1) + var sql string + var args []interface{} + session.Statement.RefTable = session.Engine.autoMap(bean) - if session.Statement.RawSQL == "" { - sql, args = session.Statement.genGetSql(bean) - } else { - sql = session.Statement.RawSQL - args = session.Statement.RawParams - } + if session.Statement.RawSQL == "" { + sql, args = session.Statement.genGetSql(bean) + } else { + sql = session.Statement.RawSQL + args = session.Statement.RawParams + } - if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { - has, err := session.cacheGet(bean, sql, args...) - if err != ErrCacheFailed { - return has, err - } - } + if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { + has, err := session.cacheGet(bean, sql, args...) + if err != ErrCacheFailed { + return has, err + } + } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return false, err - } - if len(resultsSlice) < 1 { - return false, nil - } + resultsSlice, err := session.query(sql, args...) + if err != nil { + return false, err + } + if len(resultsSlice) < 1 { + return false, nil + } - err = session.scanMapIntoStruct(bean, resultsSlice[0]) - if err != nil { - return true, err - } - if len(resultsSlice) == 1 { - return true, nil - } else { - return true, errors.New("More than one record") - } + err = session.scanMapIntoStruct(bean, resultsSlice[0]) + if err != nil { + return true, err + } + if len(resultsSlice) == 1 { + return true, nil + } else { + return true, errors.New("More than one record") + } } // Count counts the records. bean's non-empty fields // are conditions. func (session *Session) Count(bean interface{}) (int64, error) { - err := session.newDb() - if err != nil { - return 0, err - } + err := session.newDb() + if err != nil { + return 0, err + } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - var sql string - var args []interface{} - if session.Statement.RawSQL == "" { - sql, args = session.Statement.genCountSql(bean) - } else { - sql = session.Statement.RawSQL - args = session.Statement.RawParams - } + var sql string + var args []interface{} + if session.Statement.RawSQL == "" { + sql, args = session.Statement.genCountSql(bean) + } else { + sql = session.Statement.RawSQL + args = session.Statement.RawParams + } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return 0, err - } + resultsSlice, err := session.query(sql, args...) + if err != nil { + return 0, err + } - var total int64 = 0 - if len(resultsSlice) > 0 { - results := resultsSlice[0] - for _, value := range results { - total, err = strconv.ParseInt(string(value), 10, 64) - break - } - } + var total int64 = 0 + if len(resultsSlice) > 0 { + results := resultsSlice[0] + for _, value := range results { + total, err = strconv.ParseInt(string(value), 10, 64) + break + } + } - return int64(total), err + return int64(total), err } // Find retrieve records from table, condiBeans's non-empty fields // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { - return errors.New("needs a pointer to a slice or a map") - } + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + return errors.New("needs a pointer to a slice or a map") + } - sliceElementType := sliceValue.Type().Elem() - var table *Table - if session.Statement.RefTable == nil { - if sliceElementType.Kind() == reflect.Ptr { - if sliceElementType.Elem().Kind() == reflect.Struct { - table = session.Engine.autoMapType(sliceElementType.Elem()) - } else { - return errors.New("slice type") - } - } else if sliceElementType.Kind() == reflect.Struct { - table = session.Engine.autoMapType(sliceElementType) - } else { - return errors.New("slice type") - } - session.Statement.RefTable = table - } else { - table = session.Statement.RefTable - } + sliceElementType := sliceValue.Type().Elem() + var table *Table + if session.Statement.RefTable == nil { + if sliceElementType.Kind() == reflect.Ptr { + if sliceElementType.Elem().Kind() == reflect.Struct { + table = session.Engine.autoMapType(sliceElementType.Elem()) + } else { + return errors.New("slice type") + } + } else if sliceElementType.Kind() == reflect.Struct { + table = session.Engine.autoMapType(sliceElementType) + } else { + return errors.New("slice type") + } + session.Statement.RefTable = table + } else { + table = session.Statement.RefTable + } - if len(condiBean) > 0 { - colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) - session.Statement.ConditionStr = strings.Join(colNames, " AND ") - session.Statement.BeanArgs = args - } + if len(condiBean) > 0 { + colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, + false, session.Statement.allUseBool, session.Statement.boolColumnMap) + session.Statement.ConditionStr = strings.Join(colNames, " AND ") + session.Statement.BeanArgs = args + } - var sql string - var args []interface{} - if session.Statement.RawSQL == "" { - var columnStr string = session.Statement.ColumnStr - if columnStr == "" { - columnStr = session.Statement.genColumnStr() - } + var sql string + var args []interface{} + if session.Statement.RawSQL == "" { + var columnStr string = session.Statement.ColumnStr + if columnStr == "" { + columnStr = session.Statement.genColumnStr() + } - session.Statement.attachInSql() + session.Statement.attachInSql() - sql = session.Statement.genSelectSql(columnStr) - args = append(session.Statement.Params, session.Statement.BeanArgs...) - } else { - sql = session.Statement.RawSQL - args = session.Statement.RawParams - } + sql = session.Statement.genSelectSql(columnStr) + args = append(session.Statement.Params, session.Statement.BeanArgs...) + } else { + sql = session.Statement.RawSQL + args = session.Statement.RawParams + } - if table.Cacher != nil && - session.Statement.UseCache && - !session.Statement.IsDistinct { - err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) - if err != ErrCacheFailed { - return err - } - session.Engine.LogWarn("Cache Find Failed") - } + if table.Cacher != nil && + session.Statement.UseCache && + !session.Statement.IsDistinct { + err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) + if err != ErrCacheFailed { + return err + } + session.Engine.LogWarn("Cache Find Failed") + } - resultsSlice, err := session.query(sql, args...) - if err != nil { - return err - } + resultsSlice, err := session.query(sql, args...) + if err != nil { + return err + } - for i, results := range resultsSlice { - var newValue reflect.Value - if sliceElementType.Kind() == reflect.Ptr { - newValue = reflect.New(sliceElementType.Elem()) - } else { - newValue = reflect.New(sliceElementType) - } - err := session.scanMapIntoStruct(newValue.Interface(), results) - if err != nil { - return err - } - if sliceValue.Kind() == reflect.Slice { - if sliceElementType.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) - } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) - } - } else if sliceValue.Kind() == reflect.Map { - var key int64 - if table.PrimaryKey != "" { - x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) - if err != nil { - return errors.New("pk " + table.PrimaryKey + " as int64: " + err.Error()) - } - key = x - } else { - key = int64(i) - } - if sliceElementType.Kind() == reflect.Ptr { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface())) - } else { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface()))) - } - } - } - return nil + for i, results := range resultsSlice { + var newValue reflect.Value + if sliceElementType.Kind() == reflect.Ptr { + newValue = reflect.New(sliceElementType.Elem()) + } else { + newValue = reflect.New(sliceElementType) + } + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err + } + if sliceValue.Kind() == reflect.Slice { + if sliceElementType.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + } else if sliceValue.Kind() == reflect.Map { + var key int64 + if table.PrimaryKey != "" { + x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) + if err != nil { + return errors.New("pk " + table.PrimaryKey + " as int64: " + err.Error()) + } + key = x + } else { + key = int64(i) + } + if sliceElementType.Kind() == reflect.Ptr { + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface())) + } else { + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface()))) + } + } + } + return nil } // Test if database is ok func (session *Session) Ping() error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - return session.Db.Ping() + return session.Db.Ping() } func (session *Session) isColumnExist(tableName, colName string) (bool, error) { - err := session.newDb() - if err != nil { - return false, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) - results, err := session.query(sql, args...) - return len(results) > 0, err + err := session.newDb() + if err != nil { + return false, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) + results, err := session.query(sql, args...) + return len(results) > 0, err } func (session *Session) isTableExist(tableName string) (bool, error) { - err := session.newDb() - if err != nil { - return false, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - sql, args := session.Engine.dialect.TableCheckSql(tableName) - results, err := session.query(sql, args...) - return len(results) > 0, err + err := session.newDb() + if err != nil { + return false, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + sql, args := session.Engine.dialect.TableCheckSql(tableName) + results, err := session.query(sql, args...) + return len(results) > 0, err } func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bool, error) { - err := session.newDb() - if err != nil { - return false, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - var idx string - if unique { - idx = uniqueName(tableName, idxName) - } else { - idx = indexName(tableName, idxName) - } - sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) - results, err := session.query(sql, args...) - return len(results) > 0, err + err := session.newDb() + if err != nil { + return false, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + var idx string + if unique { + idx = uniqueName(tableName, idxName) + } else { + idx = indexName(tableName, idxName) + } + sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) + results, err := session.query(sql, args...) + return len(results) > 0, err } // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { - indexes, err := session.Engine.dialect.GetIndexes(tableName) - if err != nil { - return false, err - } + indexes, err := session.Engine.dialect.GetIndexes(tableName) + if err != nil { + return false, err + } - for _, index := range indexes { - //fmt.Println(i, "new:", cols, "-old:", index.Cols) - if sliceEq(index.Cols, cols) { - if unique { - return index.Type == UniqueType, nil - } else { - return index.Type == IndexType, nil - } - } - } - return false, nil + for _, index := range indexes { + //fmt.Println(i, "new:", cols, "-old:", index.Cols) + if sliceEq(index.Cols, cols) { + if unique { + return index.Type == UniqueType, nil + } else { + return index.Type == IndexType, nil + } + } + } + return false, nil } func (session *Session) addColumn(colName string) error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - //fmt.Println(session.Statement.RefTable) - col := session.Statement.RefTable.Columns[colName] - sql, args := session.Statement.genAddColumnStr(col) - _, err = session.exec(sql, args...) - return err + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + //fmt.Println(session.Statement.RefTable) + col := session.Statement.RefTable.Columns[colName] + sql, args := session.Statement.genAddColumnStr(col) + _, err = session.exec(sql, args...) + return err } func (session *Session) addIndex(tableName, idxName string) error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - //fmt.Println(idxName) - cols := session.Statement.RefTable.Indexes[idxName].Cols - sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) - _, err = session.exec(sql, args...) - return err + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + //fmt.Println(idxName) + cols := session.Statement.RefTable.Indexes[idxName].Cols + sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) + _, err = session.exec(sql, args...) + return err } func (session *Session) addUnique(tableName, uqeName string) error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } - //fmt.Println(uqeName, session.Statement.RefTable.Uniques) - cols := session.Statement.RefTable.Indexes[uqeName].Cols - sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) - _, err = session.exec(sql, args...) - return err + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + //fmt.Println(uqeName, session.Statement.RefTable.Uniques) + cols := session.Statement.RefTable.Indexes[uqeName].Cols + sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) + _, err = session.exec(sql, args...) + return err } // To be deleted func (session *Session) dropAll() error { - err := session.newDb() - if err != nil { - return err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - for _, table := range session.Engine.Tables { - session.Statement.Init() - session.Statement.RefTable = table - sql := session.Statement.genDropSQL() - _, err := session.exec(sql) - if err != nil { - return err - } - } - return nil + for _, table := range session.Engine.Tables { + session.Statement.Init() + session.Statement.RefTable = table + sql := session.Statement.genDropSQL() + _, err := session.exec(sql) + if err != nil { + return err + } + } + return nil } func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err error) { - result := make(map[string][]byte) - var scanResultContainers []interface{} - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers = append(scanResultContainers, &scanResultContainer) - } - if err := rows.Scan(scanResultContainers...); err != nil { - return nil, err - } - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + result := make(map[string][]byte) + var scanResultContainers []interface{} + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := rows.Scan(scanResultContainers...); err != nil { + return nil, err + } + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore - if rawValue.Interface() == nil { - //fmt.Println("ignore ...", key, rawValue) - continue - } - aa := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - var str string - switch aa.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - result[key] = []byte(str) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - result[key] = []byte(str) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - result[key] = []byte(str) - case reflect.String: - str = vv.String() - result[key] = []byte(str) - case reflect.Array, reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - result[key] = rawValue.Interface().([]byte) - str = string(result[key]) - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - //时间类型 - case reflect.Struct: - if aa.String() == "time.Time" { - str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) - result[key] = []byte(str) - } else { - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - case reflect.Bool: - str = strconv.FormatBool(vv.Bool()) - result[key] = []byte(str) - case reflect.Complex128, reflect.Complex64: - str = fmt.Sprintf("%v", vv.Complex()) - result[key] = []byte(str) - /* TODO: unsupported types below - case reflect.Map: - case reflect.Ptr: - case reflect.Uintptr: - case reflect.UnsafePointer: - case reflect.Chan, reflect.Func, reflect.Interface: - */ - default: - return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) - } - } - return result, nil + //if row is null then ignore + if rawValue.Interface() == nil { + //fmt.Println("ignore ...", key, rawValue) + continue + } + aa := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + var str string + switch aa.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + result[key] = []byte(str) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + result[key] = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + result[key] = []byte(str) + case reflect.String: + str = vv.String() + result[key] = []byte(str) + case reflect.Array, reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + result[key] = rawValue.Interface().([]byte) + str = string(result[key]) + default: + return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) + } + //时间类型 + case reflect.Struct: + if aa.String() == "time.Time" { + str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) + result[key] = []byte(str) + } else { + return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) + } + case reflect.Bool: + str = strconv.FormatBool(vv.Bool()) + result[key] = []byte(str) + case reflect.Complex128, reflect.Complex64: + str = fmt.Sprintf("%v", vv.Complex()) + result[key] = []byte(str) + /* TODO: unsupported types below + case reflect.Map: + case reflect.Ptr: + case reflect.Uintptr: + case reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.Interface: + */ + default: + return nil, errors.New(fmt.Sprintf("Unsupported struct type %v", vv.Type().Name())) + } + } + return result, nil } func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { - fields, err := rows.Columns() - if err != nil { - return nil, err - } - for rows.Next() { - result, err := row2map(rows, fields) - if err != nil { - return nil, err - } - resultsSlice = append(resultsSlice, result) - } + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2map(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } - return resultsSlice, nil + return resultsSlice, nil } func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } - session.Engine.LogSQL(sql) - session.Engine.LogSQL(paramStr) + session.Engine.LogSQL(sql) + session.Engine.LogSQL(paramStr) - if session.IsAutoCommit { - return query(session.Db, sql, paramStr...) - } - return txQuery(session.Tx, sql, paramStr...) + if session.IsAutoCommit { + return query(session.Db, sql, paramStr...) + } + return txQuery(session.Tx, sql, paramStr...) } func txQuery(tx *sql.Tx, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - rows, err := tx.Query(sql, params...) - if err != nil { - return nil, err - } - defer rows.Close() + rows, err := tx.Query(sql, params...) + if err != nil { + return nil, err + } + defer rows.Close() - return rows2maps(rows) + return rows2maps(rows) } func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - s, err := db.Prepare(sql) - if err != nil { - return nil, err - } - defer s.Close() - rows, err := s.Query(params...) - if err != nil { - return nil, err - } - defer rows.Close() + s, err := db.Prepare(sql) + if err != nil { + return nil, err + } + defer s.Close() + rows, err := s.Query(params...) + if err != nil { + return nil, err + } + defer rows.Close() - return rows2maps(rows) + return rows2maps(rows) } // Exec a raw sql and return records as []map[string][]byte func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - err = session.newDb() - if err != nil { - return nil, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err = session.newDb() + if err != nil { + return nil, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - return session.query(sql, paramStr...) + return session.query(sql, paramStr...) } // insert one or more beans func (session *Session) Insert(beans ...interface{}) (int64, error) { - var affected int64 = 0 - var err error = nil - err = session.newDb() - if err != nil { - return 0, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + var affected int64 = 0 + var err error = nil + err = session.newDb() + if err != nil { + return 0, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - for _, bean := range beans { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - if sliceValue.Kind() == reflect.Slice { - if session.Engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - size := sliceValue.Len() - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } - } - } else { - cnt, err := session.innerInsert(bean) - if err != nil { - return affected, err - } - affected += cnt - } - } + for _, bean := range beans { + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + if sliceValue.Kind() == reflect.Slice { + if session.Engine.SupportInsertMany() { + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err + } + affected += cnt + } else { + size := sliceValue.Len() + for i := 0; i < size; i++ { + cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) + if err != nil { + return affected, err + } + affected += cnt + } + } + } else { + cnt, err := session.innerInsert(bean) + if err != nil { + return affected, err + } + affected += cnt + } + } - return affected, err + return affected, err } func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) { - sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - if sliceValue.Kind() != reflect.Slice { - return 0, errors.New("needs a pointer to a slice") - } + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice { + return 0, errors.New("needs a pointer to a slice") + } - bean := sliceValue.Index(0).Interface() - sliceElementType := rType(bean) + bean := sliceValue.Index(0).Interface() + sliceElementType := rType(bean) - table := session.Engine.autoMapType(sliceElementType) - session.Statement.RefTable = table + table := session.Engine.autoMapType(sliceElementType) + session.Statement.RefTable = table - size := sliceValue.Len() + size := sliceValue.Len() - colNames := make([]string, 0) - colMultiPlaces := make([]string, 0) - var args = make([]interface{}, 0) - cols := make([]*Column, 0) + colNames := make([]string, 0) + colMultiPlaces := make([]string, 0) + var args = make([]interface{}, 0) + cols := make([]*Column, 0) - for i := 0; i < size; i++ { - elemValue := sliceValue.Index(i).Interface() - colPlaces := make([]string, 0) + for i := 0; i < size; i++ { + elemValue := sliceValue.Index(i).Interface() + colPlaces := make([]string, 0) - // handle BeforeInsertProcessor - // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? - for _, closure := range session.beforeClosures { - closure(elemValue) - } + // handle BeforeInsertProcessor + // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? + for _, closure := range session.beforeClosures { + closure(elemValue) + } - if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok { - processor.BeforeInsert() - } - // -- + if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok { + processor.BeforeInsert() + } + // -- - if i == 0 { - for _, col := range table.Columns { - fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue - } - if col.MapType == ONLYFROMDB { - continue - } - if session.Statement.ColumnStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; !ok { - continue - } - } - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - args = append(args, time.Now()) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } + if i == 0 { + for _, col := range table.Columns { + fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) + if col.IsAutoIncrement && fieldValue.Int() == 0 { + continue + } + if col.MapType == ONLYFROMDB { + continue + } + if session.Statement.ColumnStr != "" { + if _, ok := session.Statement.columnMap[col.Name]; !ok { + continue + } + } + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { + args = append(args, time.Now()) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return 0, err + } + args = append(args, arg) + } - colNames = append(colNames, col.Name) - cols = append(cols, col) - colPlaces = append(colPlaces, "?") - } - } else { - for _, col := range cols { - fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue - } - if col.MapType == ONLYFROMDB { - continue - } - if session.Statement.ColumnStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; !ok { - continue - } - } - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - args = append(args, time.Now()) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return 0, err - } - args = append(args, arg) - } + colNames = append(colNames, col.Name) + cols = append(cols, col) + colPlaces = append(colPlaces, "?") + } + } else { + for _, col := range cols { + fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) + if col.IsAutoIncrement && fieldValue.Int() == 0 { + continue + } + if col.MapType == ONLYFROMDB { + continue + } + if session.Statement.ColumnStr != "" { + if _, ok := session.Statement.columnMap[col.Name]; !ok { + continue + } + } + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { + args = append(args, time.Now()) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return 0, err + } + args = append(args, arg) + } - colPlaces = append(colPlaces, "?") - } - } - colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) - } - cleanupProcessorsClosures(&session.beforeClosures) + colPlaces = append(colPlaces, "?") + } + } + colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) + } + cleanupProcessorsClosures(&session.beforeClosures) - statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", - session.Engine.QuoteStr(), - session.Statement.TableName(), - session.Engine.QuoteStr(), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr(), - strings.Join(colMultiPlaces, "),(")) + statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", + session.Engine.QuoteStr(), + session.Statement.TableName(), + session.Engine.QuoteStr(), + session.Engine.QuoteStr(), + strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), + session.Engine.QuoteStr(), + strings.Join(colMultiPlaces, "),(")) - res, err := session.exec(statement, args...) - if err != nil { - return 0, err - } + res, err := session.exec(statement, args...) + if err != nil { + return 0, err + } - if table.Cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) - } + if table.Cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } - lenAfterClosures := len(session.afterClosures) - for i := 0; i < size; i++ { - elemValue := sliceValue.Index(i).Interface() - // handle AfterInsertProcessor - if session.IsAutoCommit { - // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? - for _, closure := range session.afterClosures { - closure(elemValue) - } - if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok { - processor.AfterInsert() - } - } else { - if lenAfterClosures > 0 { - if value, has := session.afterInsertBeans[elemValue]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterInsertBeans[elemValue] = &afterClosures - } + lenAfterClosures := len(session.afterClosures) + for i := 0; i < size; i++ { + elemValue := sliceValue.Index(i).Interface() + // handle AfterInsertProcessor + if session.IsAutoCommit { + // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? + for _, closure := range session.afterClosures { + closure(elemValue) + } + if processor, ok := interface{}(elemValue).(AfterInsertProcessor); ok { + processor.AfterInsert() + } + } else { + if lenAfterClosures > 0 { + if value, has := session.afterInsertBeans[elemValue]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterInsertBeans[elemValue] = &afterClosures + } - } else { - if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok { - session.afterInsertBeans[elemValue] = nil - } - } - } - } - cleanupProcessorsClosures(&session.afterClosures) - return res.RowsAffected() + } else { + if _, ok := interface{}(elemValue).(AfterInsertProcessor); ok { + session.afterInsertBeans[elemValue] = nil + } + } + } + } + cleanupProcessorsClosures(&session.afterClosures) + return res.RowsAffected() } // Insert multiple records func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { - err := session.newDb() - if err != nil { - return 0, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return 0, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - return session.innerInsertMulti(rowsSlicePtr) + return session.innerInsertMulti(rowsSlicePtr) } // convert a db data([]byte) to a field value func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data []byte) error { - if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { - return structConvert.FromDB(data) - } + if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { + return structConvert.FromDB(data) + } - var v interface{} - key := col.Name - fieldType := fieldValue.Type() - + var v interface{} + key := col.Name + fieldType := fieldValue.Type() - //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) - switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: - x := reflect.New(fieldType) + //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) + switch fieldType.Kind() { + case reflect.Complex64, reflect.Complex128: + x := reflect.New(fieldType) - err := json.Unmarshal(data, x.Interface()) - if err != nil { - session.Engine.LogSQL(err) - return err - } - fieldValue.Set(x.Elem()) - case reflect.Slice, reflect.Array, reflect.Map: - v = data - t := fieldType.Elem() - k := t.Kind() - if col.SQLType.IsText() { - x := reflect.New(fieldType) - err := json.Unmarshal(data, x.Interface()) - if err != nil { - session.Engine.LogSQL(err) - return err - } - fieldValue.Set(x.Elem()) - } else if col.SQLType.IsBlob() { - if k == reflect.Uint8 { - fieldValue.Set(reflect.ValueOf(v)) - } else { - x := reflect.New(fieldType) - err := json.Unmarshal(data, x.Interface()) - if err != nil { - session.Engine.LogSQL(err) - return err - } - fieldValue.Set(x.Elem()) - } - } else { - return ErrUnSupportedType - } - case reflect.String: - fieldValue.SetString(string(data)) - case reflect.Bool: - d := string(data) - v, err := strconv.ParseBool(d) - if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(v)) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - //fmt.Println("######", x, data) - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.SetInt(x) - case reflect.Float32, reflect.Float64: - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) - } - fieldValue.SetFloat(x) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.SetUint(x) - //Now only support Time type - case reflect.Struct: - if fieldType.String() == "time.Time" { - sdata := strings.TrimSpace(string(data)) - var x time.Time - var err error + err := json.Unmarshal(data, x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + case reflect.Slice, reflect.Array, reflect.Map: + v = data + t := fieldType.Elem() + k := t.Kind() + if col.SQLType.IsText() { + x := reflect.New(fieldType) + err := json.Unmarshal(data, x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } else if col.SQLType.IsBlob() { + if k == reflect.Uint8 { + fieldValue.Set(reflect.ValueOf(v)) + } else { + x := reflect.New(fieldType) + err := json.Unmarshal(data, x.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(x.Elem()) + } + } else { + return ErrUnSupportedType + } + case reflect.String: + fieldValue.SetString(string(data)) + case reflect.Bool: + d := string(data) + v, err := strconv.ParseBool(d) + if err != nil { + return errors.New("arg " + key + " as bool: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(v)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + sdata := string(data) + var x int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int64(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x, err = strconv.ParseInt(sdata, 16, 64) + } else if strings.HasPrefix(sdata, "0") { + x, err = strconv.ParseInt(sdata, 8, 64) + } else { + x, err = strconv.ParseInt(sdata, 10, 64) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.SetInt(x) + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + fieldValue.SetFloat(x) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.SetUint(x) + //Now only support Time type + case reflect.Struct: + if fieldType.String() == "time.Time" { + sdata := strings.TrimSpace(string(data)) + var x time.Time + var err error - if sdata == "0000-00-00 00:00:00" || - sdata == "0001-01-01 00:00:00" { - } else if !strings.ContainsAny(sdata, "- :") { - // time stamp - sd, err := strconv.ParseInt(sdata, 10, 64) - if err == nil { - x = time.Unix(0, sd) - } - } else if len(sdata) > 19 { - x, err = time.Parse(time.RFC3339Nano, sdata) - if err != nil { - x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) - } - } else if len(sdata) == 19 { - x, err = time.Parse("2006-01-02 15:04:05", sdata) - } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { - x, err = time.Parse("2006-01-02", sdata) - } else if col.SQLType.Name == Time { - if len(sdata) > 8 { - sdata = sdata[len(sdata)-8:] - } - st := fmt.Sprintf("2006-01-02 %v", sdata) - x, err = time.Parse("2006-01-02 15:04:05", st) - } else { - return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) - } - if err != nil { - return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) - } + if sdata == "0000-00-00 00:00:00" || + sdata == "0001-01-01 00:00:00" { + } else if !strings.ContainsAny(sdata, "- :") { + // time stamp + sd, err := strconv.ParseInt(sdata, 10, 64) + if err == nil { + x = time.Unix(0, sd) + } + } else if len(sdata) > 19 { + x, err = time.Parse(time.RFC3339Nano, sdata) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) + } + } else if len(sdata) == 19 { + x, err = time.Parse("2006-01-02 15:04:05", sdata) + } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { + x, err = time.Parse("2006-01-02", sdata) + } else if col.SQLType.Name == Time { + if len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + } + st := fmt.Sprintf("2006-01-02 %v", sdata) + x, err = time.Parse("2006-01-02 15:04:05", st) + } else { + return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) + } + if err != nil { + return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) + } - v = x - fieldValue.Set(reflect.ValueOf(v)) - } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(fieldValue.Type()) - if table != nil { - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - if x != 0 { - structInter := reflect.New(fieldValue.Type()) - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(x).Get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist!") - } - } - } else { - return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) - } - } - case reflect.Ptr: - // !nashtsai! TODO merge duplicated codes above - //typeStr := fieldType.String() - switch fieldType { - // case "*string": - case reflect.TypeOf(&c_EMPTY_STRING): - x := string(data) - fieldValue.Set(reflect.ValueOf(&x)) - // case "*bool": - case reflect.TypeOf(&c_BOOL_DEFAULT): - d := string(data) - v, err := strconv.ParseBool(d) - if err != nil { - return errors.New("arg " + key + " as bool: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&v)) - // case "*complex64": - case reflect.TypeOf(&c_COMPLEX64_DEFAULT): - var x complex64 - err := json.Unmarshal(data, &x) - if err != nil { - session.Engine.LogSQL(err) - return err - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*complex128": - case reflect.TypeOf(&c_COMPLEX128_DEFAULT): - var x complex128 - err := json.Unmarshal(data, &x) - if err != nil { - session.Engine.LogSQL(err) - return err - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*float64": - case reflect.TypeOf(&c_FLOAT64_DEFAULT): - x, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return errors.New("arg " + key + " as float64: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*float32": - case reflect.TypeOf(&c_FLOAT32_DEFAULT): - var x float32 - x1, err := strconv.ParseFloat(string(data), 32) - if err != nil { - return errors.New("arg " + key + " as float32: " + err.Error()) - } - x = float32(x1) - fieldValue.Set(reflect.ValueOf(&x)) - // case "*time.Time": - case reflect.TypeOf(&c_TIME_DEFAULT): - sdata := strings.TrimSpace(string(data)) - var x time.Time - var err error + v = x + fieldValue.Set(reflect.ValueOf(v)) + } else if session.Statement.UseCascade { + table := session.Engine.autoMapType(fieldValue.Type()) + if table != nil { + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + if x != 0 { + structInter := reflect.New(fieldValue.Type()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(x).Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v = structInter.Elem().Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } + } + } else { + return errors.New("unsupported struct type in Scan: " + fieldValue.Type().String()) + } + } + case reflect.Ptr: + // !nashtsai! TODO merge duplicated codes above + //typeStr := fieldType.String() + switch fieldType { + // case "*string": + case reflect.TypeOf(&c_EMPTY_STRING): + x := string(data) + fieldValue.Set(reflect.ValueOf(&x)) + // case "*bool": + case reflect.TypeOf(&c_BOOL_DEFAULT): + d := string(data) + v, err := strconv.ParseBool(d) + if err != nil { + return errors.New("arg " + key + " as bool: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&v)) + // case "*complex64": + case reflect.TypeOf(&c_COMPLEX64_DEFAULT): + var x complex64 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*complex128": + case reflect.TypeOf(&c_COMPLEX128_DEFAULT): + var x complex128 + err := json.Unmarshal(data, &x) + if err != nil { + session.Engine.LogSQL(err) + return err + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*float64": + case reflect.TypeOf(&c_FLOAT64_DEFAULT): + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*float32": + case reflect.TypeOf(&c_FLOAT32_DEFAULT): + var x float32 + x1, err := strconv.ParseFloat(string(data), 32) + if err != nil { + return errors.New("arg " + key + " as float32: " + err.Error()) + } + x = float32(x1) + fieldValue.Set(reflect.ValueOf(&x)) + // case "*time.Time": + case reflect.TypeOf(&c_TIME_DEFAULT): + sdata := strings.TrimSpace(string(data)) + var x time.Time + var err error - if sdata == "0000-00-00 00:00:00" || - sdata == "0001-01-01 00:00:00" { - } else if !strings.ContainsAny(sdata, "- :") { - // time stamp - sd, err := strconv.ParseInt(sdata, 10, 64) - if err == nil { - x = time.Unix(0, sd) - } - } else if len(sdata) > 19 { - x, err = time.Parse(time.RFC3339Nano, sdata) - if err != nil { - x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) - } - } else if len(sdata) == 19 { - x, err = time.Parse("2006-01-02 15:04:05", sdata) - } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { - x, err = time.Parse("2006-01-02", sdata) - } else if col.SQLType.Name == Time { - if len(sdata) > 8 { - sdata = sdata[len(sdata)-8:] - } - st := fmt.Sprintf("2006-01-02 %v", sdata) - x, err = time.Parse("2006-01-02 15:04:05", st) - } else { - return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) - } - if err != nil { - return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) - } + if sdata == "0000-00-00 00:00:00" || + sdata == "0001-01-01 00:00:00" { + } else if !strings.ContainsAny(sdata, "- :") { + // time stamp + sd, err := strconv.ParseInt(sdata, 10, 64) + if err == nil { + x = time.Unix(0, sd) + } + } else if len(sdata) > 19 { + x, err = time.Parse(time.RFC3339Nano, sdata) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) + } + } else if len(sdata) == 19 { + x, err = time.Parse("2006-01-02 15:04:05", sdata) + } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { + x, err = time.Parse("2006-01-02", sdata) + } else if col.SQLType.Name == Time { + if len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + } + st := fmt.Sprintf("2006-01-02 %v", sdata) + x, err = time.Parse("2006-01-02 15:04:05", st) + } else { + return errors.New(fmt.Sprintf("unsupported time format %v", string(data))) + } + if err != nil { + return errors.New(fmt.Sprintf("unsupported time format %v: %v", string(data), err)) + } - v = x - fieldValue.Set(reflect.ValueOf(&x)) - // case "*uint64": - case reflect.TypeOf(&c_UINT64_DEFAULT): - var x uint64 - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*uint": - case reflect.TypeOf(&c_UINT_DEFAULT): - var x uint - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - x = uint(x1) - fieldValue.Set(reflect.ValueOf(&x)) - // case "*uint32": - case reflect.TypeOf(&c_UINT32_DEFAULT): - var x uint32 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - x = uint32(x1) - fieldValue.Set(reflect.ValueOf(&x)) - // case "*uint8": - case reflect.TypeOf(&c_UINT8_DEFAULT): - var x uint8 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - x = uint8(x1) - fieldValue.Set(reflect.ValueOf(&x)) - // case "*uint16": - case reflect.TypeOf(&c_UINT16_DEFAULT): - var x uint16 - x1, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - x = uint16(x1) - fieldValue.Set(reflect.ValueOf(&x)) - // case "*int64": - case reflect.TypeOf(&c_INT64_DEFAULT): - sdata := string(data) - var x int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { - if len(data) == 1 { - x = int64(data[0]) - } else { - x = 0 - } - //fmt.Println("######", x, data) - } else if strings.HasPrefix(sdata, "0x") { - x, err = strconv.ParseInt(sdata, 16, 64) - } else if strings.HasPrefix(sdata, "0") { - x, err = strconv.ParseInt(sdata, 8, 64) - } else { - x, err = strconv.ParseInt(sdata, 10, 64) - } - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*int": - case reflect.TypeOf(&c_INT_DEFAULT): - sdata := string(data) - var x int - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { - if len(data) == 1 { - x = int(data[0]) - } else { - x = 0 - } - //fmt.Println("######", x, data) - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int(x1) - } - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*int32": - case reflect.TypeOf(&c_INT32_DEFAULT): - sdata := string(data) - var x int32 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { - if len(data) == 1 { - x = int32(data[0]) - } else { - x = 0 - } - //fmt.Println("######", x, data) - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int32(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int32(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int32(x1) - } - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*int8": - case reflect.TypeOf(&c_INT8_DEFAULT): - sdata := string(data) - var x int8 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { - if len(data) == 1 { - x = int8(data[0]) - } else { - x = 0 - } - //fmt.Println("######", x, data) - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int8(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int8(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int8(x1) - } - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - // case "*int16": - case reflect.TypeOf(&c_INT16_DEFAULT): - sdata := string(data) - var x int16 - var x1 int64 - var err error - // for mysql, when use bit, it returned \x01 - if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { - if len(data) == 1 { - x = int16(data[0]) - } else { - x = 0 - } - //fmt.Println("######", x, data) - } else if strings.HasPrefix(sdata, "0x") { - x1, err = strconv.ParseInt(sdata, 16, 64) - x = int16(x1) - } else if strings.HasPrefix(sdata, "0") { - x1, err = strconv.ParseInt(sdata, 8, 64) - x = int16(x1) - } else { - x1, err = strconv.ParseInt(sdata, 10, 64) - x = int16(x1) - } - if err != nil { - return errors.New("arg " + key + " as int: " + err.Error()) - } - fieldValue.Set(reflect.ValueOf(&x)) - default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) - } - default: - return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) - } + v = x + fieldValue.Set(reflect.ValueOf(&x)) + // case "*uint64": + case reflect.TypeOf(&c_UINT64_DEFAULT): + var x uint64 + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*uint": + case reflect.TypeOf(&c_UINT_DEFAULT): + var x uint + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint(x1) + fieldValue.Set(reflect.ValueOf(&x)) + // case "*uint32": + case reflect.TypeOf(&c_UINT32_DEFAULT): + var x uint32 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint32(x1) + fieldValue.Set(reflect.ValueOf(&x)) + // case "*uint8": + case reflect.TypeOf(&c_UINT8_DEFAULT): + var x uint8 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint8(x1) + fieldValue.Set(reflect.ValueOf(&x)) + // case "*uint16": + case reflect.TypeOf(&c_UINT16_DEFAULT): + var x uint16 + x1, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + x = uint16(x1) + fieldValue.Set(reflect.ValueOf(&x)) + // case "*int64": + case reflect.TypeOf(&c_INT64_DEFAULT): + sdata := string(data) + var x int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int64(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x, err = strconv.ParseInt(sdata, 16, 64) + } else if strings.HasPrefix(sdata, "0") { + x, err = strconv.ParseInt(sdata, 8, 64) + } else { + x, err = strconv.ParseInt(sdata, 10, 64) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*int": + case reflect.TypeOf(&c_INT_DEFAULT): + sdata := string(data) + var x int + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*int32": + case reflect.TypeOf(&c_INT32_DEFAULT): + sdata := string(data) + var x int32 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int32(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int32(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int32(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int32(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*int8": + case reflect.TypeOf(&c_INT8_DEFAULT): + sdata := string(data) + var x int8 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int8(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int8(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int8(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int8(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + // case "*int16": + case reflect.TypeOf(&c_INT16_DEFAULT): + sdata := string(data) + var x int16 + var x1 int64 + var err error + // for mysql, when use bit, it returned \x01 + if col.SQLType.Name == Bit && + strings.Contains(session.Engine.DriverName, "mysql") { + if len(data) == 1 { + x = int16(data[0]) + } else { + x = 0 + } + //fmt.Println("######", x, data) + } else if strings.HasPrefix(sdata, "0x") { + x1, err = strconv.ParseInt(sdata, 16, 64) + x = int16(x1) + } else if strings.HasPrefix(sdata, "0") { + x1, err = strconv.ParseInt(sdata, 8, 64) + x = int16(x1) + } else { + x1, err = strconv.ParseInt(sdata, 10, 64) + x = int16(x1) + } + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + fieldValue.Set(reflect.ValueOf(&x)) + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + } + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + } - return nil + return nil } // convert a field value of a struct to interface for put into db func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (interface{}, error) { - if fieldValue.CanAddr() { - if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { - data, err := fieldConvert.ToDB() - if err != nil { - return 0, err - } else { - return string(data), nil - } - } - } - fieldType := fieldValue.Type() - k := fieldType.Kind() - if k == reflect.Ptr { - if fieldValue.IsNil() { - return nil, nil - } else if !fieldValue.IsValid() { - session.Engine.LogWarn("the field[", col.FieldName, "] is invalid") - return nil, nil - } else { - // !nashtsai! deference pointer type to instance type - fieldValue = fieldValue.Elem() - fieldType = fieldValue.Type() - k = fieldType.Kind() - } - } + if fieldValue.CanAddr() { + if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return 0, err + } else { + return string(data), nil + } + } + } + fieldType := fieldValue.Type() + k := fieldType.Kind() + if k == reflect.Ptr { + if fieldValue.IsNil() { + return nil, nil + } else if !fieldValue.IsValid() { + session.Engine.LogWarn("the field[", col.FieldName, "] is invalid") + return nil, nil + } else { + // !nashtsai! deference pointer type to instance type + fieldValue = fieldValue.Elem() + fieldType = fieldValue.Type() + k = fieldType.Kind() + } + } - switch k { - case reflect.Bool: - if fieldValue.Bool() { - return 1, nil - } else { - return 0, nil - } - case reflect.String: - return fieldValue.String(), nil - case reflect.Struct: - if fieldType.String() == "time.Time" { - if col.SQLType.Name == Time { - //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") - s := fieldValue.Interface().(time.Time).Format(time.RFC3339) - return s[11:19], nil - } else if col.SQLType.Name == Date { - return fieldValue.Interface().(time.Time).Format("2006-01-02"), nil - } else if col.SQLType.Name == TimeStampz { - return fieldValue.Interface().(time.Time).Format(time.RFC3339Nano), nil - } - return fieldValue.Interface(), nil - } - if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { - if fieldTable.PrimaryKey != "" { - pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) - return pkField.Interface(), nil - } else { - return 0, errors.New("no primary key") - } - } else { - return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type())) - } - case reflect.Complex64, reflect.Complex128: - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - session.Engine.LogSQL(err) - return 0, err - } - return string(bytes), nil - case reflect.Array, reflect.Slice, reflect.Map: - if !fieldValue.IsValid() { - return fieldValue.Interface(), nil - } + switch k { + case reflect.Bool: + if fieldValue.Bool() { + return 1, nil + } else { + return 0, nil + } + case reflect.String: + return fieldValue.String(), nil + case reflect.Struct: + if fieldType.String() == "time.Time" { + if col.SQLType.Name == Time { + //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") + s := fieldValue.Interface().(time.Time).Format(time.RFC3339) + return s[11:19], nil + } else if col.SQLType.Name == Date { + return fieldValue.Interface().(time.Time).Format("2006-01-02"), nil + } else if col.SQLType.Name == TimeStampz { + return fieldValue.Interface().(time.Time).Format(time.RFC3339Nano), nil + } + return fieldValue.Interface(), nil + } + if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { + if fieldTable.PrimaryKey != "" { + pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) + return pkField.Interface(), nil + } else { + return 0, errors.New("no primary key") + } + } else { + return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type())) + } + case reflect.Complex64, reflect.Complex128: + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + return string(bytes), nil + case reflect.Array, reflect.Slice, reflect.Map: + if !fieldValue.IsValid() { + return fieldValue.Interface(), nil + } - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - session.Engine.LogSQL(err) - return 0, err - } - return string(bytes), nil - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (k == reflect.Array || k == reflect.Slice) && - (fieldValue.Type().Elem().Kind() == reflect.Uint8) { - bytes = fieldValue.Bytes() - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - session.Engine.LogSQL(err) - return 0, err - } - } - return bytes, nil - } else { - return nil, ErrUnSupportedType - } - default: - return fieldValue.Interface(), nil - } + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (k == reflect.Array || k == reflect.Slice) && + (fieldValue.Type().Elem().Kind() == reflect.Uint8) { + bytes = fieldValue.Bytes() + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogSQL(err) + return 0, err + } + } + return bytes, nil + } else { + return nil, ErrUnSupportedType + } + default: + return fieldValue.Interface(), nil + } } func (session *Session) innerInsert(bean interface{}) (int64, error) { - table := session.Engine.autoMap(bean) - session.Statement.RefTable = table + table := session.Engine.autoMap(bean) + session.Statement.RefTable = table - // handle BeforeInsertProcessor - for _, closure := range session.beforeClosures { - closure(bean) - } - cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used + // handle BeforeInsertProcessor + for _, closure := range session.beforeClosures { + closure(bean) + } + cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used - if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { - processor.BeforeInsert() - } - // -- + if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok { + processor.BeforeInsert() + } + // -- - colNames, args, err := table.genCols(session, bean, false, false) - if err != nil { - return 0, err - } + colNames, args, err := table.genCols(session, bean, false, false) + if err != nil { + return 0, err + } - colPlaces := strings.Repeat("?, ", len(colNames)) - colPlaces = colPlaces[0 : len(colPlaces)-2] + colPlaces := strings.Repeat("?, ", len(colNames)) + colPlaces = colPlaces[0 : len(colPlaces)-2] - sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", - session.Engine.QuoteStr(), - session.Statement.TableName(), - session.Engine.QuoteStr(), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.Quote(", ")), - session.Engine.QuoteStr(), - colPlaces) + sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)", + session.Engine.QuoteStr(), + session.Statement.TableName(), + session.Engine.QuoteStr(), + session.Engine.QuoteStr(), + strings.Join(colNames, session.Engine.Quote(", ")), + session.Engine.QuoteStr(), + colPlaces) - handleAfterInsertProcessorFunc := func(bean interface{}) { + handleAfterInsertProcessorFunc := func(bean interface{}) { - if session.IsAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - processor.AfterInsert() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterInsertBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterInsertBeans[bean] = &afterClosures - } + if session.IsAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + processor.AfterInsert() + } + } else { + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 { + if value, has := session.afterInsertBeans[bean]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterInsertBeans[bean] = &afterClosures + } - } else { - if _, ok := interface{}(bean).(AfterInsertProcessor); ok { - session.afterInsertBeans[bean] = nil - } - } - } - cleanupProcessorsClosures(&session.afterClosures) // cleanup after used - } + } else { + if _, ok := interface{}(bean).(AfterInsertProcessor); ok { + session.afterInsertBeans[bean] = nil + } + } + } + cleanupProcessorsClosures(&session.afterClosures) // cleanup after used + } - // for postgres, many of them didn't implement lastInsertId, so we should - // implemented it ourself. - if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { - res, err := session.exec(sql, args...) - if err != nil { - return 0, err - } else { - handleAfterInsertProcessorFunc(bean) - } + // for postgres, many of them didn't implement lastInsertId, so we should + // implemented it ourself. + if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } else { + handleAfterInsertProcessorFunc(bean) + } - if table.Cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) - } + if table.Cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } - if table.Version != "" && session.Statement.checkVersion { - verValue := table.VersionColumn().ValueOf(bean) - if verValue.IsValid() && verValue.CanSet() { - verValue.SetInt(1) - } - } + if table.Version != "" && session.Statement.checkVersion { + verValue := table.VersionColumn().ValueOf(bean) + if verValue.IsValid() && verValue.CanSet() { + verValue.SetInt(1) + } + } - if table.PrimaryKey == "" || table.PKColumn().SQLType.IsText() { - return res.RowsAffected() - } + if table.PrimaryKey == "" || table.PKColumn().SQLType.IsText() { + return res.RowsAffected() + } - var id int64 = 0 - id, err = res.LastInsertId() - if err != nil || id <= 0 { - return res.RowsAffected() - } + var id int64 = 0 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return res.RowsAffected() + } - pkValue := table.PKColumn().ValueOf(bean) - if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { - return res.RowsAffected() - } + pkValue := table.PKColumn().ValueOf(bean) + if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { + return res.RowsAffected() + } - var v interface{} = id - switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: - v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - v = uint(id) - } - pkValue.Set(reflect.ValueOf(v)) + var v interface{} = id + switch pkValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + v = int(id) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + v = uint(id) + } + pkValue.Set(reflect.ValueOf(v)) - return res.RowsAffected() - } else { - sql = sql + " RETURNING (id)" - res, err := session.query(sql, args...) - if err != nil { - return 0, err - } else { - handleAfterInsertProcessorFunc(bean) - } + return res.RowsAffected() + } else { + sql = sql + " RETURNING (id)" + res, err := session.query(sql, args...) + if err != nil { + return 0, err + } else { + handleAfterInsertProcessorFunc(bean) + } - if table.Cacher != nil && session.Statement.UseCache { - session.cacheInsert(session.Statement.TableName()) - } + if table.Cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } - if table.Version != "" && session.Statement.checkVersion { - verValue := table.VersionColumn().ValueOf(bean) - if verValue.IsValid() && verValue.CanSet() { - verValue.SetInt(1) - } - } + if table.Version != "" && session.Statement.checkVersion { + verValue := table.VersionColumn().ValueOf(bean) + if verValue.IsValid() && verValue.CanSet() { + verValue.SetInt(1) + } + } - if len(res) < 1 { - return 0, errors.New("insert no error but not returned id") - } + if len(res) < 1 { + return 0, errors.New("insert no error but not returned id") + } - idByte := res[0][table.PrimaryKey] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil { - return 1, err - } + idByte := res[0][table.PrimaryKey] + id, err := strconv.ParseInt(string(idByte), 10, 64) + if err != nil { + return 1, err + } - pkValue := table.PKColumn().ValueOf(bean) - if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { - return 1, nil - } + pkValue := table.PKColumn().ValueOf(bean) + if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { + return 1, nil + } - var v interface{} = id - switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: - v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: - v = uint(id) - } - pkValue.Set(reflect.ValueOf(v)) + var v interface{} = id + switch pkValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + v = int(id) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + v = uint(id) + } + pkValue.Set(reflect.ValueOf(v)) - return 1, nil - } + return 1, nil + } } // Method InsertOne insert only one struct into database as a record. // The in parameter bean must a struct or a point to struct. The return // parameter is lastInsertId and error func (session *Session) InsertOne(bean interface{}) (int64, error) { - err := session.newDb() - if err != nil { - return 0, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return 0, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - return session.innerInsert(bean) + return session.innerInsert(bean) } func (statement *Statement) convertUpdateSql(sql string) (string, string) { - if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { - return "", "" - } - sqls := splitNNoCase(sql, "where", 2) - if len(sqls) != 2 { - if len(sqls) == 1 { - return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - statement.Engine.Quote(statement.RefTable.PrimaryKey), - statement.Engine.Quote(statement.RefTable.Name)) - } - return "", "" - } + if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { + return "", "" + } + sqls := splitNNoCase(sql, "where", 2) + if len(sqls) != 2 { + if len(sqls) == 1 { + return sqls[0], fmt.Sprintf("SELECT %v FROM %v", + statement.Engine.Quote(statement.RefTable.PrimaryKey), + statement.Engine.Quote(statement.RefTable.Name)) + } + return "", "" + } - var whereStr = sqls[1] + var whereStr = sqls[1] - //TODO: for postgres only, if any other database? - if strings.Contains(sqls[1], "$") { - dollers := strings.Split(sqls[1], "$") - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf("$%v %v", i+1, ccs[1]) - } - } + //TODO: for postgres only, if any other database? + if strings.Contains(sqls[1], "$") { + dollers := strings.Split(sqls[1], "$") + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf("$%v %v", i+1, ccs[1]) + } + } - return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - statement.Engine.Quote(statement.RefTable.PrimaryKey), statement.Engine.Quote(statement.TableName()), - whereStr) + return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", + statement.Engine.Quote(statement.RefTable.PrimaryKey), statement.Engine.Quote(statement.TableName()), + whereStr) } func (session *Session) cacheInsert(tables ...string) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { - return ErrCacheFailed - } + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } - table := session.Statement.RefTable - cacher := table.Cacher + table := session.Statement.RefTable + cacher := table.Cacher - for _, t := range tables { - session.Engine.LogDebug("cache clear:", t) - cacher.ClearIds(t) - } + for _, t := range tables { + session.Engine.LogDebug("cache clear:", t) + cacher.ClearIds(t) + } - return nil + return nil } func (session *Session) cacheUpdate(sql string, args ...interface{}) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { - return ErrCacheFailed - } + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } - oldhead, newsql := session.Statement.convertUpdateSql(sql) - if newsql == "" { - return ErrCacheFailed - } - for _, filter := range session.Engine.Filters { - newsql = filter.Do(newsql, session) - } - session.Engine.LogDebug("[xorm:cacheUpdate] new sql", oldhead, newsql) + oldhead, newsql := session.Statement.convertUpdateSql(sql) + if newsql == "" { + return ErrCacheFailed + } + for _, filter := range session.Engine.Filters { + newsql = filter.Do(newsql, session) + } + session.Engine.LogDebug("[xorm:cacheUpdate] new sql", oldhead, newsql) - var nStart int - if len(args) > 0 { - if strings.Index(sql, "?") > -1 { - nStart = strings.Count(oldhead, "?") - } else { - // only for pq, TODO: if any other databse? - nStart = strings.Count(oldhead, "$") - } - } - table := session.Statement.RefTable - cacher := table.Cacher - tableName := session.Statement.TableName() - session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:]) - ids, err := getCacheSql(cacher, tableName, newsql, args[nStart:]) - if err != nil { - resultsSlice, err := session.query(newsql, args[nStart:]...) - if err != nil { - return err - } - session.Engine.LogDebug("[xorm:cacheUpdate] find updated id", resultsSlice) + var nStart int + if len(args) > 0 { + if strings.Index(sql, "?") > -1 { + nStart = strings.Count(oldhead, "?") + } else { + // only for pq, TODO: if any other databse? + nStart = strings.Count(oldhead, "$") + } + } + table := session.Statement.RefTable + cacher := table.Cacher + tableName := session.Statement.TableName() + session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:]) + ids, err := getCacheSql(cacher, tableName, newsql, args[nStart:]) + if err != nil { + resultsSlice, err := session.query(newsql, args[nStart:]...) + if err != nil { + return err + } + session.Engine.LogDebug("[xorm:cacheUpdate] find updated id", resultsSlice) - ids = make([]int64, 0) - if len(resultsSlice) > 0 { - for _, data := range resultsSlice { - var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { - return errors.New("no id") - } else { - id, err = strconv.ParseInt(string(v), 10, 64) - if err != nil { - return err - } - } - ids = append(ids, id) - } - } - } /*else { - session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) - cacher.DelIds(tableName, genSqlKey(newsql, args)) - }*/ + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + for _, data := range resultsSlice { + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + } + ids = append(ids, id) + } + } + } /*else { + session.Engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) + cacher.DelIds(tableName, genSqlKey(newsql, args)) + }*/ - for _, id := range ids { - if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := splitNNoCase(sql, "where", 2) - if len(sqls) == 0 || len(sqls) > 2 { - return ErrCacheFailed - } + for _, id := range ids { + if bean := cacher.GetBean(tableName, id); bean != nil { + sqls := splitNNoCase(sql, "where", 2) + if len(sqls) == 0 || len(sqls) > 2 { + return ErrCacheFailed + } - sqls = splitNNoCase(sqls[0], "set", 2) - if len(sqls) != 2 { - return ErrCacheFailed - } - kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") - for idx, kv := range kvs { - sps := strings.SplitN(kv, "=", 2) - sps2 := strings.Split(sps[0], ".") - colName := sps2[len(sps2)-1] - if strings.Contains(colName, "`") { - colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) - } else if strings.Contains(colName, session.Engine.QuoteStr()) { - colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) - } else { - session.Engine.LogDebug("[xorm:cacheUpdate] cannot find column", tableName, colName) - return ErrCacheFailed - } + sqls = splitNNoCase(sqls[0], "set", 2) + if len(sqls) != 2 { + return ErrCacheFailed + } + kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") + for idx, kv := range kvs { + sps := strings.SplitN(kv, "=", 2) + sps2 := strings.Split(sps[0], ".") + colName := sps2[len(sps2)-1] + if strings.Contains(colName, "`") { + colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) + } else if strings.Contains(colName, session.Engine.QuoteStr()) { + colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) + } else { + session.Engine.LogDebug("[xorm:cacheUpdate] cannot find column", tableName, colName) + return ErrCacheFailed + } - if col, ok := table.Columns[colName]; ok { - fieldValue := col.ValueOf(bean) - session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.Statement.checkVersion { - fieldValue.SetInt(fieldValue.Int() + 1) - fmt.Println("-----", fieldValue) - } else { - fieldValue.Set(reflect.ValueOf(args[idx])) - fmt.Println("xxxxxx", fieldValue) - } - } else { - session.Engine.LogError("[xorm:cacheUpdate] ERROR: column %v is not table %v's", - colName, table.Name) - } - } + if col, ok := table.Columns[colName]; ok { + fieldValue := col.ValueOf(bean) + session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) + if col.IsVersion && session.Statement.checkVersion { + fieldValue.SetInt(fieldValue.Int() + 1) + fmt.Println("-----", fieldValue) + } else { + fieldValue.Set(reflect.ValueOf(args[idx])) + fmt.Println("xxxxxx", fieldValue) + } + } else { + session.Engine.LogError("[xorm:cacheUpdate] ERROR: column %v is not table %v's", + colName, table.Name) + } + } - session.Engine.LogDebug("[xorm:cacheUpdate] update cache", tableName, id, bean) - cacher.PutBean(tableName, id, bean) - } - } - session.Engine.LogDebug("[xorm:cacheUpdate] clear cached table sql:", tableName) - cacher.ClearIds(tableName) - return nil + session.Engine.LogDebug("[xorm:cacheUpdate] update cache", tableName, id, bean) + cacher.PutBean(tableName, id, bean) + } + } + session.Engine.LogDebug("[xorm:cacheUpdate] clear cached table sql:", tableName) + cacher.ClearIds(tableName) + return nil } // Update records, bean's non-empty fields are updated contents, @@ -2445,325 +2444,325 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { // You should call UseBool if you have bool to use. // 2.float32 & float64 may be not inexact as conditions func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { - err := session.newDb() - if err != nil { - return 0, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return 0, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - t := rType(bean) + t := rType(bean) - var colNames []string - var args []interface{} - var table *Table + var colNames []string + var args []interface{} + var table *Table - // handle before update processors - for _, closure := range session.beforeClosures { - closure(bean) - } - cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used - if processor, ok := interface{}(bean).(BeforeUpdateProcessor); ok { - processor.BeforeUpdate() - } - // -- + // handle before update processors + for _, closure := range session.beforeClosures { + closure(bean) + } + cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used + if processor, ok := interface{}(bean).(BeforeUpdateProcessor); ok { + processor.BeforeUpdate() + } + // -- - if t.Kind() == reflect.Struct { - table = session.Engine.autoMap(bean) - session.Statement.RefTable = table + if t.Kind() == reflect.Struct { + table = session.Engine.autoMap(bean) + session.Statement.RefTable = table - if session.Statement.ColumnStr == "" { - colNames, args = buildConditions(session.Engine, table, bean, false, false, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) - } else { - colNames, args, err = table.genCols(session, bean, true, true) - if err != nil { - return 0, err - } - } - } else if t.Kind() == reflect.Map { - if session.Statement.RefTable == nil { - return 0, ErrTableNotFound - } - table = session.Statement.RefTable - colNames = make([]string, 0) - args = make([]interface{}, 0) - bValue := reflect.Indirect(reflect.ValueOf(bean)) + if session.Statement.ColumnStr == "" { + colNames, args = buildConditions(session.Engine, table, bean, false, false, + false, session.Statement.allUseBool, session.Statement.boolColumnMap) + } else { + colNames, args, err = table.genCols(session, bean, true, true) + if err != nil { + return 0, err + } + } + } else if t.Kind() == reflect.Map { + if session.Statement.RefTable == nil { + return 0, ErrTableNotFound + } + table = session.Statement.RefTable + colNames = make([]string, 0) + args = make([]interface{}, 0) + bValue := reflect.Indirect(reflect.ValueOf(bean)) - for _, v := range bValue.MapKeys() { - colNames = append(colNames, session.Engine.Quote(v.String())+" = ?") - args = append(args, bValue.MapIndex(v).Interface()) - } - } else { - return 0, ErrParamsType - } + for _, v := range bValue.MapKeys() { + colNames = append(colNames, session.Engine.Quote(v.String())+" = ?") + args = append(args, bValue.MapIndex(v).Interface()) + } + } else { + return 0, ErrParamsType + } - if session.Statement.UseAutoTime && table.Updated != "" { - colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") - args = append(args, time.Now()) - } + if session.Statement.UseAutoTime && table.Updated != "" { + colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?") + args = append(args, time.Now()) + } - var condiColNames []string - var condiArgs []interface{} + var condiColNames []string + var condiArgs []interface{} - if len(condiBean) > 0 { - condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) - } + if len(condiBean) > 0 { + condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, + false, session.Statement.allUseBool, session.Statement.boolColumnMap) + } - var condition = "" - session.Statement.processIdParam() - st := session.Statement - defer session.Statement.Init() - if st.WhereStr != "" { - condition = fmt.Sprintf("%v", st.WhereStr) - } + var condition = "" + session.Statement.processIdParam() + st := session.Statement + defer session.Statement.Init() + if st.WhereStr != "" { + condition = fmt.Sprintf("%v", st.WhereStr) + } - if condition == "" { - if len(condiColNames) > 0 { - condition = fmt.Sprintf("%v", strings.Join(condiColNames, " AND ")) - } - } else { - if len(condiColNames) > 0 { - condition = fmt.Sprintf("(%v) AND (%v)", condition, strings.Join(condiColNames, " AND ")) - } - } + if condition == "" { + if len(condiColNames) > 0 { + condition = fmt.Sprintf("%v", strings.Join(condiColNames, " AND ")) + } + } else { + if len(condiColNames) > 0 { + condition = fmt.Sprintf("(%v) AND (%v)", condition, strings.Join(condiColNames, " AND ")) + } + } - var sql, inSql string - var inArgs []interface{} - if table.Version != "" && session.Statement.checkVersion { - if condition != "" { - condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition, - session.Engine.Quote(table.Version)) - } else { - condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version)) - } - inSql, inArgs = session.Statement.genInSql() - if len(inSql) > 0 { - if condition != "" { - condition += " AND " + inSql - } else { - condition = "WHERE " + inSql - } - } + var sql, inSql string + var inArgs []interface{} + if table.Version != "" && session.Statement.checkVersion { + if condition != "" { + condition = fmt.Sprintf("WHERE (%v) AND %v = ?", condition, + session.Engine.Quote(table.Version)) + } else { + condition = fmt.Sprintf("WHERE %v = ?", session.Engine.Quote(table.Version)) + } + inSql, inArgs = session.Statement.genInSql() + if len(inSql) > 0 { + if condition != "" { + condition += " AND " + inSql + } else { + condition = "WHERE " + inSql + } + } - sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", - session.Engine.Quote(session.Statement.TableName()), - strings.Join(colNames, ", "), - session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", - condition) + sql = fmt.Sprintf("UPDATE %v SET %v, %v %v", + session.Engine.Quote(session.Statement.TableName()), + strings.Join(colNames, ", "), + session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", + condition) - condiArgs = append(condiArgs, table.VersionColumn().ValueOf(bean).Interface()) - } else { - if condition != "" { - condition = "WHERE " + condition - } - inSql, inArgs = session.Statement.genInSql() - if len(inSql) > 0 { - if condition != "" { - condition += " AND " + inSql - } else { - condition = "WHERE " + inSql - } - } + condiArgs = append(condiArgs, table.VersionColumn().ValueOf(bean).Interface()) + } else { + if condition != "" { + condition = "WHERE " + condition + } + inSql, inArgs = session.Statement.genInSql() + if len(inSql) > 0 { + if condition != "" { + condition += " AND " + inSql + } else { + condition = "WHERE " + inSql + } + } - sql = fmt.Sprintf("UPDATE %v SET %v %v", - session.Engine.Quote(session.Statement.TableName()), - strings.Join(colNames, ", "), - condition) - } + sql = fmt.Sprintf("UPDATE %v SET %v %v", + session.Engine.Quote(session.Statement.TableName()), + strings.Join(colNames, ", "), + condition) + } - args = append(args, st.Params...) - args = append(args, inArgs...) - args = append(args, condiArgs...) + args = append(args, st.Params...) + args = append(args, inArgs...) + args = append(args, condiArgs...) - res, err := session.exec(sql, args...) - if err != nil { - return 0, err - } + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } - if table.Cacher != nil && session.Statement.UseCache { - //session.cacheUpdate(sql, args...) - table.Cacher.ClearIds(session.Statement.TableName()) - table.Cacher.ClearBeans(session.Statement.TableName()) - } + if table.Cacher != nil && session.Statement.UseCache { + //session.cacheUpdate(sql, args...) + table.Cacher.ClearIds(session.Statement.TableName()) + table.Cacher.ClearBeans(session.Statement.TableName()) + } - // handle after update processors - if session.IsAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.Engine.LogDebug(session.Statement.TableName(), " has after update processor") - processor.AfterUpdate() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterUpdateBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterUpdateBeans[bean] = &afterClosures - } + // handle after update processors + if session.IsAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { + session.Engine.LogDebug(session.Statement.TableName(), " has after update processor") + processor.AfterUpdate() + } + } else { + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 { + if value, has := session.afterUpdateBeans[bean]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterUpdateBeans[bean] = &afterClosures + } - } else { - if _, ok := interface{}(bean).(AfterInsertProcessor); ok { - session.afterUpdateBeans[bean] = nil - } - } - } - cleanupProcessorsClosures(&session.afterClosures) // cleanup after used - // -- + } else { + if _, ok := interface{}(bean).(AfterInsertProcessor); ok { + session.afterUpdateBeans[bean] = nil + } + } + } + cleanupProcessorsClosures(&session.afterClosures) // cleanup after used + // -- - return res.RowsAffected() + return res.RowsAffected() } func (session *Session) cacheDelete(sql string, args ...interface{}) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { - return ErrCacheFailed - } + if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + return ErrCacheFailed + } - for _, filter := range session.Engine.Filters { - sql = filter.Do(sql, session) - } + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } - newsql := session.Statement.convertIdSql(sql) - if newsql == "" { - return ErrCacheFailed - } + newsql := session.Statement.convertIdSql(sql) + if newsql == "" { + return ErrCacheFailed + } - cacher := session.Statement.RefTable.Cacher - tableName := session.Statement.TableName() - ids, err := getCacheSql(cacher, tableName, newsql, args) - if err != nil { - resultsSlice, err := session.query(newsql, args...) - if err != nil { - return err - } - ids = make([]int64, 0) - if len(resultsSlice) > 0 { - for _, data := range resultsSlice { - var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { - return errors.New("no id") - } else { - id, err = strconv.ParseInt(string(v), 10, 64) - if err != nil { - return err - } - } - ids = append(ids, id) - } - } - } /*else { - session.Engine.LogDebug("delete cache sql %v", newsql) - cacher.DelIds(tableName, genSqlKey(newsql, args)) - }*/ + cacher := session.Statement.RefTable.Cacher + tableName := session.Statement.TableName() + ids, err := getCacheSql(cacher, tableName, newsql, args) + if err != nil { + resultsSlice, err := session.query(newsql, args...) + if err != nil { + return err + } + ids = make([]int64, 0) + if len(resultsSlice) > 0 { + for _, data := range resultsSlice { + var id int64 + if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + return errors.New("no id") + } else { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + } + ids = append(ids, id) + } + } + } /*else { + session.Engine.LogDebug("delete cache sql %v", newsql) + cacher.DelIds(tableName, genSqlKey(newsql, args)) + }*/ - for _, id := range ids { - session.Engine.LogDebug("[xorm:cacheDelete] delete cache obj", tableName, id) - cacher.DelBean(tableName, id) - } - session.Engine.LogDebug("[xorm:cacheDelete] clear cache table", tableName) - cacher.ClearIds(tableName) - return nil + for _, id := range ids { + session.Engine.LogDebug("[xorm:cacheDelete] delete cache obj", tableName, id) + cacher.DelBean(tableName, id) + } + session.Engine.LogDebug("[xorm:cacheDelete] clear cache table", tableName) + cacher.ClearIds(tableName) + return nil } // Delete records, bean's non-empty fields are conditions func (session *Session) Delete(bean interface{}) (int64, error) { - err := session.newDb() - if err != nil { - return 0, err - } - defer session.Statement.Init() - if session.IsAutoClose { - defer session.Close() - } + err := session.newDb() + if err != nil { + return 0, err + } + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } - // handle before delete processors - for _, closure := range session.beforeClosures { - closure(bean) - } - cleanupProcessorsClosures(&session.beforeClosures) + // handle before delete processors + for _, closure := range session.beforeClosures { + closure(bean) + } + cleanupProcessorsClosures(&session.beforeClosures) - if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { - processor.BeforeDelete() - } - // -- + if processor, ok := interface{}(bean).(BeforeDeleteProcessor); ok { + processor.BeforeDelete() + } + // -- - table := session.Engine.autoMap(bean) - session.Statement.RefTable = table - colNames, args := buildConditions(session.Engine, table, bean, true, true, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) + table := session.Engine.autoMap(bean) + session.Statement.RefTable = table + colNames, args := buildConditions(session.Engine, table, bean, true, true, + false, session.Statement.allUseBool, session.Statement.boolColumnMap) - var condition = "" + var condition = "" - session.Statement.processIdParam() - if session.Statement.WhereStr != "" { - condition = session.Statement.WhereStr - if len(colNames) > 0 { - condition += " AND " + strings.Join(colNames, " AND ") - } - } else { - condition = strings.Join(colNames, " AND ") - } - inSql, inArgs := session.Statement.genInSql() - if len(inSql) > 0 { - if len(condition) > 0 { - condition += " AND " - } - condition += inSql - args = append(args, inArgs...) - } - if len(condition) == 0 { - return 0, ErrNeedDeletedCond - } + session.Statement.processIdParam() + if session.Statement.WhereStr != "" { + condition = session.Statement.WhereStr + if len(colNames) > 0 { + condition += " AND " + strings.Join(colNames, " AND ") + } + } else { + condition = strings.Join(colNames, " AND ") + } + inSql, inArgs := session.Statement.genInSql() + if len(inSql) > 0 { + if len(condition) > 0 { + condition += " AND " + } + condition += inSql + args = append(args, inArgs...) + } + if len(condition) == 0 { + return 0, ErrNeedDeletedCond + } - sql := fmt.Sprintf("DELETE FROM %v WHERE %v", - session.Engine.Quote(session.Statement.TableName()), condition) + sql := fmt.Sprintf("DELETE FROM %v WHERE %v", + session.Engine.Quote(session.Statement.TableName()), condition) - args = append(session.Statement.Params, args...) + args = append(session.Statement.Params, args...) - if table.Cacher != nil && session.Statement.UseCache { - session.cacheDelete(sql, args...) - } + if table.Cacher != nil && session.Statement.UseCache { + session.cacheDelete(sql, args...) + } - res, err := session.exec(sql, args...) - if err != nil { - return 0, err - } + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } - // handle after delete processors - if session.IsAutoCommit { - for _, closure := range session.afterClosures { - closure(bean) - } - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } else { - lenAfterClosures := len(session.afterClosures) - if lenAfterClosures > 0 { - if value, has := session.afterDeleteBeans[bean]; has && value != nil { - *value = append(*value, session.afterClosures...) - } else { - afterClosures := make([]func(interface{}), lenAfterClosures) - copy(afterClosures, session.afterClosures) - session.afterDeleteBeans[bean] = &afterClosures - } + // handle after delete processors + if session.IsAutoCommit { + for _, closure := range session.afterClosures { + closure(bean) + } + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() + } + } else { + lenAfterClosures := len(session.afterClosures) + if lenAfterClosures > 0 { + if value, has := session.afterDeleteBeans[bean]; has && value != nil { + *value = append(*value, session.afterClosures...) + } else { + afterClosures := make([]func(interface{}), lenAfterClosures) + copy(afterClosures, session.afterClosures) + session.afterDeleteBeans[bean] = &afterClosures + } - } else { - if _, ok := interface{}(bean).(AfterInsertProcessor); ok { - session.afterDeleteBeans[bean] = nil - } - } - } - cleanupProcessorsClosures(&session.afterClosures) - // -- + } else { + if _, ok := interface{}(bean).(AfterInsertProcessor); ok { + session.afterDeleteBeans[bean] = nil + } + } + } + cleanupProcessorsClosures(&session.afterClosures) + // -- - return res.RowsAffected() + return res.RowsAffected() } diff --git a/sqlite3.go b/sqlite3.go index eb42e999..84a9d1b0 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1,223 +1,229 @@ package xorm import ( - "database/sql" - "strings" + "database/sql" + "strings" ) type sqlite3 struct { - base + base +} + +type sqlite3Parser struct { +} + +func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) { + return &uri{dbType: SQLITE, dbName: dataSourceName}, nil } func (db *sqlite3) Init(drivername, dataSourceName string) error { - db.base.init(drivername, dataSourceName) - return nil + return db.base.init(&sqlite3Parser{}, drivername, dataSourceName) } func (db *sqlite3) SqlType(c *Column) string { - switch t := c.SQLType.Name; t { - case Date, DateTime, TimeStamp, Time: - return Numeric - case TimeStampz: - return Text - case Char, Varchar, TinyText, Text, MediumText, LongText: - return Text - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: - return Integer - case Float, Double, Real: - return Real - case Decimal, Numeric: - return Numeric - case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: - return Blob - case Serial, BigSerial: - c.IsPrimaryKey = true - c.IsAutoIncrement = true - c.Nullable = false - return Integer - default: - return t - } + switch t := c.SQLType.Name; t { + case Date, DateTime, TimeStamp, Time: + return Numeric + case TimeStampz: + return Text + case Char, Varchar, TinyText, Text, MediumText, LongText: + return Text + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: + return Integer + case Float, Double, Real: + return Real + case Decimal, Numeric: + return Numeric + case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: + return Blob + case Serial, BigSerial: + c.IsPrimaryKey = true + c.IsAutoIncrement = true + c.Nullable = false + return Integer + default: + return t + } } func (db *sqlite3) SupportInsertMany() bool { - return true + return true } func (db *sqlite3) QuoteStr() string { - return "`" + return "`" } func (db *sqlite3) AutoIncrStr() string { - return "AUTOINCREMENT" + return "AUTOINCREMENT" } func (db *sqlite3) SupportEngine() bool { - return false + return false } func (db *sqlite3) SupportCharset() bool { - return false + return false } func (db *sqlite3) IndexOnTable() bool { - return false + return false } func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{idxName} - return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args + args := []interface{}{idxName} + return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args } func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName} - sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - return sql, args + args := []interface{}{tableName} + sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" + return sql, args } func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{tableName} - s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } + args := []interface{}{tableName} + s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } - var sql string - for _, record := range res { - for name, content := range record { - if name == "sql" { - sql = string(content) - } - } - } + var sql string + for _, record := range res { + for name, content := range record { + if name == "sql" { + sql = string(content) + } + } + } - nStart := strings.Index(sql, "(") - nEnd := strings.Index(sql, ")") - colCreates := strings.Split(sql[nStart+1:nEnd], ",") - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, colStr := range colCreates { - fields := strings.Fields(strings.TrimSpace(colStr)) - col := new(Column) - col.Indexes = make(map[string]bool) - col.Nullable = true - for idx, field := range fields { - if idx == 0 { - col.Name = strings.Trim(field, "`[] ") - continue - } else if idx == 1 { - col.SQLType = SQLType{field, 0, 0} - } - switch field { - case "PRIMARY": - col.IsPrimaryKey = true - case "AUTOINCREMENT": - col.IsAutoIncrement = true - case "NULL": - if fields[idx-1] == "NOT" { - col.Nullable = false - } else { - col.Nullable = true - } - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } - return colSeq, cols, nil + nStart := strings.Index(sql, "(") + nEnd := strings.Index(sql, ")") + colCreates := strings.Split(sql[nStart+1:nEnd], ",") + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, colStr := range colCreates { + fields := strings.Fields(strings.TrimSpace(colStr)) + col := new(Column) + col.Indexes = make(map[string]bool) + col.Nullable = true + for idx, field := range fields { + if idx == 0 { + col.Name = strings.Trim(field, "`[] ") + continue + } else if idx == 1 { + col.SQLType = SQLType{field, 0, 0} + } + switch field { + case "PRIMARY": + col.IsPrimaryKey = true + case "AUTOINCREMENT": + col.IsAutoIncrement = true + case "NULL": + if fields[idx-1] == "NOT" { + col.Nullable = false + } else { + col.Nullable = true + } + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + return colSeq, cols, nil } func (db *sqlite3) GetTables() ([]*Table, error) { - args := []interface{}{} - s := "SELECT name FROM sqlite_master WHERE type='table'" + args := []interface{}{} + s := "SELECT name FROM sqlite_master WHERE type='table'" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "name": - table.Name = string(content) - } - } - if table.Name == "sqlite_sequence" { - continue - } - tables = append(tables, table) - } - return tables, nil + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "name": + table.Name = string(content) + } + } + if table.Name == "sqlite_sequence" { + continue + } + tables = append(tables, table) + } + return tables, nil } func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{tableName} - s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{tableName} + s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - indexes := make(map[string]*Index, 0) - for _, record := range res { - var sql string - index := new(Index) - for name, content := range record { - if name == "sql" { - sql = string(content) - } - } + indexes := make(map[string]*Index, 0) + for _, record := range res { + var sql string + index := new(Index) + for name, content := range record { + if name == "sql" { + sql = string(content) + } + } - nNStart := strings.Index(sql, "INDEX") - nNEnd := strings.Index(sql, "ON") - indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") - //fmt.Println(indexName) - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - index.Name = indexName[5+len(tableName) : len(indexName)] - } else { - index.Name = indexName - } + nNStart := strings.Index(sql, "INDEX") + nNEnd := strings.Index(sql, "ON") + indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") + //fmt.Println(indexName) + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + index.Name = indexName[5+len(tableName) : len(indexName)] + } else { + index.Name = indexName + } - if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { - index.Type = UniqueType - } else { - index.Type = IndexType - } + if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { + index.Type = UniqueType + } else { + index.Type = IndexType + } - nStart := strings.Index(sql, "(") - nEnd := strings.Index(sql, ")") - colIndexes := strings.Split(sql[nStart+1:nEnd], ",") + nStart := strings.Index(sql, "(") + nEnd := strings.Index(sql, ")") + colIndexes := strings.Split(sql[nStart+1:nEnd], ",") - index.Cols = make([]string, 0) - for _, col := range colIndexes { - index.Cols = append(index.Cols, strings.Trim(col, "` []")) - } - indexes[index.Name] = index - } + index.Cols = make([]string, 0) + for _, col := range colIndexes { + index.Cols = append(index.Cols, strings.Trim(col, "` []")) + } + indexes[index.Name] = index + } - return indexes, nil + return indexes, nil } diff --git a/statement.go b/statement.go index feafe873..2ceeee69 100644 --- a/statement.go +++ b/statement.go @@ -1,144 +1,144 @@ package xorm import ( - "bytes" - "fmt" - "reflect" - //"strconv" - "encoding/json" - "strings" - "time" + //"bytes" + "fmt" + "reflect" + //"strconv" + "encoding/json" + "strings" + "time" ) // !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparision var ( - c_EMPTY_STRING = "" - c_BOOL_DEFAULT = false - c_COMPLEX64_DEFAULT = complex64(0) - c_COMPLEX128_DEFAULT = complex128(0) - c_FLOAT32_DEFAULT = float32(0) - c_FLOAT64_DEFAULT = float64(0) - c_INT64_DEFAULT = int64(0) - c_UINT64_DEFAULT = uint64(0) - c_INT32_DEFAULT = int32(0) - c_UINT32_DEFAULT = uint32(0) - c_INT16_DEFAULT = int16(0) - c_UINT16_DEFAULT = uint16(0) - c_INT8_DEFAULT = int8(0) - c_UINT8_DEFAULT = uint8(0) - c_INT_DEFAULT = int(0) - c_UINT_DEFAULT = uint(0) - c_TIME_DEFAULT time.Time = time.Unix(0, 0) + c_EMPTY_STRING = "" + c_BOOL_DEFAULT = false + c_COMPLEX64_DEFAULT = complex64(0) + c_COMPLEX128_DEFAULT = complex128(0) + c_FLOAT32_DEFAULT = float32(0) + c_FLOAT64_DEFAULT = float64(0) + c_INT64_DEFAULT = int64(0) + c_UINT64_DEFAULT = uint64(0) + c_INT32_DEFAULT = int32(0) + c_UINT32_DEFAULT = uint32(0) + c_INT16_DEFAULT = int16(0) + c_UINT16_DEFAULT = uint16(0) + c_INT8_DEFAULT = int8(0) + c_UINT8_DEFAULT = uint8(0) + c_INT_DEFAULT = int(0) + c_UINT_DEFAULT = uint(0) + c_TIME_DEFAULT time.Time = time.Unix(0, 0) ) // statement save all the sql info for executing SQL type Statement struct { - RefTable *Table - Engine *Engine - Start int - LimitN int - WhereStr string - IdParam *PK - Params []interface{} - OrderStr string - JoinStr string - GroupByStr string - HavingStr string - ColumnStr string - columnMap map[string]bool - OmitStr string - ConditionStr string - AltTableName string - RawSQL string - RawParams []interface{} - UseCascade bool - UseAutoJoin bool - StoreEngine string - Charset string - BeanArgs []interface{} - UseCache bool - UseAutoTime bool - IsDistinct bool - allUseBool bool - checkVersion bool - boolColumnMap map[string]bool - inColumns map[string][]interface{} + RefTable *Table + Engine *Engine + Start int + LimitN int + WhereStr string + IdParam *PK + Params []interface{} + OrderStr string + JoinStr string + GroupByStr string + HavingStr string + ColumnStr string + columnMap map[string]bool + OmitStr string + ConditionStr string + AltTableName string + RawSQL string + RawParams []interface{} + UseCascade bool + UseAutoJoin bool + StoreEngine string + Charset string + BeanArgs []interface{} + UseCache bool + UseAutoTime bool + IsDistinct bool + allUseBool bool + checkVersion bool + boolColumnMap map[string]bool + inColumns map[string][]interface{} } // init func (statement *Statement) Init() { - statement.RefTable = nil - statement.Start = 0 - statement.LimitN = 0 - statement.WhereStr = "" - statement.Params = make([]interface{}, 0) - statement.OrderStr = "" - statement.UseCascade = true - statement.JoinStr = "" - statement.GroupByStr = "" - statement.HavingStr = "" - statement.ColumnStr = "" - statement.OmitStr = "" - statement.columnMap = make(map[string]bool) - statement.ConditionStr = "" - statement.AltTableName = "" - statement.RawSQL = "" - statement.RawParams = make([]interface{}, 0) - statement.BeanArgs = make([]interface{}, 0) - statement.UseCache = statement.Engine.UseCache - statement.UseAutoTime = true - statement.IsDistinct = false - statement.allUseBool = false - statement.boolColumnMap = make(map[string]bool) - statement.checkVersion = true - statement.inColumns = make(map[string][]interface{}) + statement.RefTable = nil + statement.Start = 0 + statement.LimitN = 0 + statement.WhereStr = "" + statement.Params = make([]interface{}, 0) + statement.OrderStr = "" + statement.UseCascade = true + statement.JoinStr = "" + statement.GroupByStr = "" + statement.HavingStr = "" + statement.ColumnStr = "" + statement.OmitStr = "" + statement.columnMap = make(map[string]bool) + statement.ConditionStr = "" + statement.AltTableName = "" + statement.RawSQL = "" + statement.RawParams = make([]interface{}, 0) + statement.BeanArgs = make([]interface{}, 0) + statement.UseCache = statement.Engine.UseCache + statement.UseAutoTime = true + statement.IsDistinct = false + statement.allUseBool = false + statement.boolColumnMap = make(map[string]bool) + statement.checkVersion = true + statement.inColumns = make(map[string][]interface{}) } // add the raw sql statement func (statement *Statement) Sql(querystring string, args ...interface{}) *Statement { - statement.RawSQL = querystring - statement.RawParams = args - return statement + statement.RawSQL = querystring + statement.RawParams = args + return statement } // add Where statment func (statement *Statement) Where(querystring string, args ...interface{}) *Statement { - statement.WhereStr = querystring - statement.Params = args - return statement + statement.WhereStr = querystring + statement.Params = args + return statement } // add Where & and statment func (statement *Statement) And(querystring string, args ...interface{}) *Statement { - if statement.WhereStr != "" { - statement.WhereStr = fmt.Sprintf("(%v) AND (%v)", statement.WhereStr, querystring) - } else { - statement.WhereStr = querystring - } - statement.Params = append(statement.Params, args...) - return statement + if statement.WhereStr != "" { + statement.WhereStr = fmt.Sprintf("(%v) AND (%v)", statement.WhereStr, querystring) + } else { + statement.WhereStr = querystring + } + statement.Params = append(statement.Params, args...) + return statement } // add Where & Or statment func (statement *Statement) Or(querystring string, args ...interface{}) *Statement { - if statement.WhereStr != "" { - statement.WhereStr = fmt.Sprintf("(%v) OR (%v)", statement.WhereStr, querystring) - } else { - statement.WhereStr = querystring - } - statement.Params = append(statement.Params, args...) - return statement + if statement.WhereStr != "" { + statement.WhereStr = fmt.Sprintf("(%v) OR (%v)", statement.WhereStr, querystring) + } else { + statement.WhereStr = querystring + } + statement.Params = append(statement.Params, args...) + return statement } // tempororily set table name func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - t := rType(tableNameOrBean) - if t.Kind() == reflect.String { - statement.AltTableName = tableNameOrBean.(string) - } else if t.Kind() == reflect.Struct { - statement.RefTable = statement.Engine.autoMapType(t) - } - return statement + t := rType(tableNameOrBean) + if t.Kind() == reflect.String { + statement.AltTableName = tableNameOrBean.(string) + } else if t.Kind() == reflect.Struct { + statement.RefTable = statement.Engine.autoMapType(t) + } + return statement } /*func (statement *Statement) genFields(bean interface{}) map[string]interface{} { @@ -259,580 +259,578 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Auto generating conditions according a struct func buildConditions(engine *Engine, table *Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, allUseBool bool, - boolColumnMap map[string]bool) ([]string, []interface{}) { + includeVersion bool, includeUpdated bool, includeNil bool, allUseBool bool, + boolColumnMap map[string]bool) ([]string, []interface{}) { - colNames := make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns { - if !includeVersion && col.IsVersion { - continue - } - if !includeUpdated && col.IsUpdated { - continue - } - fieldValue := col.ValueOf(bean) - fieldType := reflect.TypeOf(fieldValue.Interface()) + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + if !includeVersion && col.IsVersion { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + fieldValue := col.ValueOf(bean) + fieldType := reflect.TypeOf(fieldValue.Interface()) - requiredField := false - if fieldType.Kind() == reflect.Ptr { - if fieldValue.IsNil() { - if includeNil { - args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) - } - continue - } else if !fieldValue.IsValid() { - continue - } else { - // dereference ptr type to instance type - fieldValue = fieldValue.Elem() - fieldType = reflect.TypeOf(fieldValue.Interface()) - requiredField = true - } - } + requiredField := false + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else if _, ok := boolColumnMap[col.Name]; ok { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - var str string - if col.SQLType.Name == Time { - s := t.UTC().Format("2006-01-02 15:04:05") - val = s[11:19] - } else if col.SQLType.Name == Date { - str = t.Format("2006-01-02") - val = str - } else { - val = t - } - } else { - engine.autoMapType(fieldValue.Type()) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) - if pkField.Int() != 0 { - val = pkField.Interface() - } else { - continue - } - } else { - val = fieldValue.Interface() - } - } - case reflect.Array, reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() { - continue - } + var val interface{} + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else if _, ok := boolColumnMap[col.Name]; ok { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + var str string + if col.SQLType.Name == Time { + s := t.UTC().Format("2006-01-02 15:04:05") + val = s[11:19] + } else if col.SQLType.Name == Date { + str = t.Format("2006-01-02") + val = str + } else { + val = t + } + } else { + engine.autoMapType(fieldValue.Type()) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) + if pkField.Int() != 0 { + val = pkField.Interface() + } else { + continue + } + } else { + val = fieldValue.Interface() + } + } + case reflect.Array, reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() { + continue + } - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.LogSQL(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.LogSQL(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogSQL(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogSQL(err) + continue + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } - args = append(args, val) - colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) - } + args = append(args, val) + colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) + } - return colNames, args + return colNames, args } // return current tableName func (statement *Statement) TableName() string { - if statement.AltTableName != "" { - return statement.AltTableName - } + if statement.AltTableName != "" { + return statement.AltTableName + } - if statement.RefTable != nil { - return statement.RefTable.Name - } - return "" + if statement.RefTable != nil { + return statement.RefTable.Name + } + return "" } // Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" func (statement *Statement) Id(id interface{}) *Statement { - - idValue := reflect.ValueOf(id) - idType := reflect.TypeOf(idValue.Interface()) - switch idType { - case reflect.TypeOf(&PK{}): - if pkPtr, ok := (id).(*PK); ok { - statement.IdParam = pkPtr - } - case reflect.TypeOf(PK{}): - if pk, ok := (id).(PK); ok { - statement.IdParam = &pk - } - default: - // TODO treat as int primitve for now, need to handle type check - statement.IdParam = &PK{id} + idValue := reflect.ValueOf(id) + idType := reflect.TypeOf(idValue.Interface()) - // !nashtsai! REVIEW although it will be user's mistake if called Id() twice with - // different value and Id should be PK's field name, however, at this stage probably - // can't tell which table is gonna be used - // if statement.WhereStr == "" { - // statement.WhereStr = "(id)=?" - // statement.Params = []interface{}{id} - // } else { - // // TODO what if id param has already passed - // statement.WhereStr = statement.WhereStr + " AND (id)=?" - // statement.Params = append(statement.Params, id) - // } - } + switch idType { + case reflect.TypeOf(&PK{}): + if pkPtr, ok := (id).(*PK); ok { + statement.IdParam = pkPtr + } + case reflect.TypeOf(PK{}): + if pk, ok := (id).(PK); ok { + statement.IdParam = &pk + } + default: + // TODO treat as int primitve for now, need to handle type check + statement.IdParam = &PK{id} - // !nashtsai! perhaps no need to validate pk values' type just let sql complaint happen + // !nashtsai! REVIEW although it will be user's mistake if called Id() twice with + // different value and Id should be PK's field name, however, at this stage probably + // can't tell which table is gonna be used + // if statement.WhereStr == "" { + // statement.WhereStr = "(id)=?" + // statement.Params = []interface{}{id} + // } else { + // // TODO what if id param has already passed + // statement.WhereStr = statement.WhereStr + " AND (id)=?" + // statement.Params = append(statement.Params, id) + // } + } - return statement + // !nashtsai! perhaps no need to validate pk values' type just let sql complaint happen + + return statement } // Generate "Where column IN (?) " statment func (statement *Statement) In(column string, args ...interface{}) *Statement { - k := strings.ToLower(column) - if params, ok := statement.inColumns[k]; ok { - statement.inColumns[k] = append(params, args...) - } else { - statement.inColumns[k] = args - } - return statement + k := strings.ToLower(column) + if params, ok := statement.inColumns[k]; ok { + statement.inColumns[k] = append(params, args...) + } else { + statement.inColumns[k] = args + } + return statement } func (statement *Statement) genInSql() (string, []interface{}) { - if len(statement.inColumns) == 0 { - return "", []interface{}{} - } + if len(statement.inColumns) == 0 { + return "", []interface{}{} + } - inStrs := make([]string, 0, len(statement.inColumns)) - args := make([]interface{}, 0) - for column, params := range statement.inColumns { - inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", statement.Engine.Quote(column), - strings.Join(makeArray("?", len(params)), ","))) - args = append(args, params...) - } + inStrs := make([]string, 0, len(statement.inColumns)) + args := make([]interface{}, 0) + for column, params := range statement.inColumns { + inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", statement.Engine.Quote(column), + strings.Join(makeArray("?", len(params)), ","))) + args = append(args, params...) + } - if len(statement.inColumns) == 1 { - return inStrs[0], args - } - return fmt.Sprintf("(%v)", strings.Join(inStrs, " AND ")), args + if len(statement.inColumns) == 1 { + return inStrs[0], args + } + return fmt.Sprintf("(%v)", strings.Join(inStrs, " AND ")), args } func (statement *Statement) attachInSql() { - inSql, inArgs := statement.genInSql() - if len(inSql) > 0 { - if statement.ConditionStr != "" { - statement.ConditionStr += " AND " - } - statement.ConditionStr += inSql - statement.Params = append(statement.Params, inArgs...) - } + inSql, inArgs := statement.genInSql() + if len(inSql) > 0 { + if statement.ConditionStr != "" { + statement.ConditionStr += " AND " + } + statement.ConditionStr += inSql + statement.Params = append(statement.Params, inArgs...) + } } func col2NewCols(columns ...string) []string { - newColumns := make([]string, 0) - for _, col := range columns { - strings.Replace(col, "`", "", -1) - strings.Replace(col, `"`, "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - newColumns = append(newColumns, strings.TrimSpace(c)) - } - } - return newColumns + newColumns := make([]string, 0) + for _, col := range columns { + strings.Replace(col, "`", "", -1) + strings.Replace(col, `"`, "", -1) + ccols := strings.Split(col, ",") + for _, c := range ccols { + newColumns = append(newColumns, strings.TrimSpace(c)) + } + } + return newColumns } // Generate "Distince col1, col2 " statment func (statement *Statement) Distinct(columns ...string) *Statement { - statement.IsDistinct = true - statement.Cols(columns...) - return statement + statement.IsDistinct = true + statement.Cols(columns...) + return statement } // Generate "col1, col2" statement func (statement *Statement) Cols(columns ...string) *Statement { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.columnMap[nc] = true - } - statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) - return statement + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.columnMap[nc] = true + } + statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) + return statement } // indicates that use bool fields as update contents and query contiditions func (statement *Statement) UseBool(columns ...string) *Statement { - if len(columns) > 0 { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.boolColumnMap[nc] = true - } - } else { - statement.allUseBool = true - } - return statement + if len(columns) > 0 { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.boolColumnMap[nc] = true + } + } else { + statement.allUseBool = true + } + return statement } // do not use the columns func (statement *Statement) Omit(columns ...string) { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.columnMap[nc] = false - } - statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.columnMap[nc] = false + } + statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } // Generate LIMIT limit statement func (statement *Statement) Top(limit int) *Statement { - statement.Limit(limit) - return statement + statement.Limit(limit) + return statement } // Generate LIMIT start, limit statement func (statement *Statement) Limit(limit int, start ...int) *Statement { - statement.LimitN = limit - if len(start) > 0 { - statement.Start = start[0] - } - return statement + statement.LimitN = limit + if len(start) > 0 { + statement.Start = start[0] + } + return statement } // Generate "Order By order" statement func (statement *Statement) OrderBy(order string) *Statement { - statement.OrderStr = order - return statement + statement.OrderStr = order + return statement } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(join_operator, tablename, condition string) *Statement { - if statement.JoinStr != "" { - statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) - } else { - statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) - } - return statement + if statement.JoinStr != "" { + statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } else { + statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } + return statement } // Generate "Group By keys" statement func (statement *Statement) GroupBy(keys string) *Statement { - statement.GroupByStr = keys - return statement + statement.GroupByStr = keys + return statement } // Generate "Having conditions" statement func (statement *Statement) Having(conditions string) *Statement { - statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) - return statement + statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) + return statement } func (statement *Statement) genColumnStr() string { - table := statement.RefTable - colNames := make([]string, 0) - for _, col := range table.Columns { - if statement.OmitStr != "" { - if _, ok := statement.columnMap[col.Name]; ok { - continue - } - } - if col.MapType == ONLYTODB { - continue - } - colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name)) - } - return strings.Join(colNames, ", ") + table := statement.RefTable + colNames := make([]string, 0) + for _, col := range table.Columns { + if statement.OmitStr != "" { + if _, ok := statement.columnMap[col.Name]; ok { + continue + } + } + if col.MapType == ONLYTODB { + continue + } + colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name)) + } + return strings.Join(colNames, ", ") } func (statement *Statement) genCreateTableSQL() string { - sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " (" + sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " (" + pkList := []string{} - pkList := []string{} + for _, colName := range statement.RefTable.ColumnsSeq { + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey { + pkList = append(pkList, col.Name) + } + } - for _, colName := range statement.RefTable.ColumnsSeq { - col := statement.RefTable.Columns[colName] - if col.IsPrimaryKey { - pkList = append(pkList, col.Name) - } - } + statement.Engine.LogDebug("len:", len(pkList)) + for _, colName := range statement.RefTable.ColumnsSeq { + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey && len(pkList) == 1 { + sql += col.String(statement.Engine.dialect) + } else { + sql += col.stringNoPk(statement.Engine.dialect) + } + sql = strings.TrimSpace(sql) + sql += ", " + } - statement.Engine.LogDebug("len:", len(pkList)) - for _, colName := range statement.RefTable.ColumnsSeq { - col := statement.RefTable.Columns[colName] - if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(statement.Engine.dialect) - } else { - sql += col.stringNoPk(statement.Engine.dialect) - } - sql = strings.TrimSpace(sql) - sql += ", " - } + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += strings.Join(pkList, ",") + sql += " ), " + } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += strings.Join(pkList, ",") - sql += " ), " - } - - sql = sql[:len(sql)-2] + ")" - if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" { - sql += " ENGINE=" + statement.StoreEngine - } - if statement.Engine.dialect.SupportCharset() && statement.Charset != "" { - sql += " DEFAULT CHARSET " + statement.Charset - } - sql += ";" - return sql + sql = sql[:len(sql)-2] + ")" + if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" { + sql += " ENGINE=" + statement.StoreEngine + } + if statement.Engine.dialect.SupportCharset() && statement.Charset != "" { + sql += " DEFAULT CHARSET " + statement.Charset + } + sql += ";" + return sql } func indexName(tableName, idxName string) string { - return fmt.Sprintf("IDX_%v_%v", tableName, idxName) + return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } func (s *Statement) genIndexSQL() []string { - var sqls []string = make([]string, 0) - tbName := s.TableName() - quote := s.Engine.Quote - for idxName, index := range s.RefTable.Indexes { - if index.Type == IndexType { - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(",")))) - sqls = append(sqls, sql) - } - } - return sqls + var sqls []string = make([]string, 0) + tbName := s.TableName() + quote := s.Engine.Quote + for idxName, index := range s.RefTable.Indexes { + if index.Type == IndexType { + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), + quote(tbName), quote(strings.Join(index.Cols, quote(",")))) + sqls = append(sqls, sql) + } + } + return sqls } func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) + return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) } func (s *Statement) genUniqueSQL() []string { - var sqls []string = make([]string, 0) - tbName := s.TableName() - quote := s.Engine.Quote - for idxName, unique := range s.RefTable.Indexes { - if unique.Type == UniqueType { - sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)), - quote(tbName), quote(strings.Join(unique.Cols, quote(",")))) - sqls = append(sqls, sql) - } - } - return sqls + var sqls []string = make([]string, 0) + tbName := s.TableName() + quote := s.Engine.Quote + for idxName, unique := range s.RefTable.Indexes { + if unique.Type == UniqueType { + sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)), + quote(tbName), quote(strings.Join(unique.Cols, quote(",")))) + sqls = append(sqls, sql) + } + } + return sqls } func (s *Statement) genDelIndexSQL() []string { - var sqls []string = make([]string, 0) - for idxName, index := range s.RefTable.Indexes { - var rIdxName string - if index.Type == UniqueType { - rIdxName = uniqueName(s.TableName(), idxName) - } else if index.Type == IndexType { - rIdxName = indexName(s.TableName(), idxName) - } - sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName)) - if s.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName())) - } - sqls = append(sqls, sql) - } - return sqls + var sqls []string = make([]string, 0) + for idxName, index := range s.RefTable.Indexes { + var rIdxName string + if index.Type == UniqueType { + rIdxName = uniqueName(s.TableName(), idxName) + } else if index.Type == IndexType { + rIdxName = indexName(s.TableName(), idxName) + } + sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName)) + if s.Engine.dialect.IndexOnTable() { + sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName())) + } + sqls = append(sqls, sql) + } + return sqls } func (s *Statement) genDropSQL() string { - sql := "DROP TABLE IF EXISTS " + s.Engine.Quote(s.TableName()) + ";" - return sql + sql := "DROP TABLE IF EXISTS " + s.Engine.Quote(s.TableName()) + ";" + return sql } func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { - table := statement.Engine.autoMap(bean) - statement.RefTable = table + table := statement.Engine.autoMap(bean) + statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true, true, - false, statement.allUseBool, statement.boolColumnMap) + colNames, args := buildConditions(statement.Engine, table, bean, true, true, + false, statement.allUseBool, statement.boolColumnMap) - statement.ConditionStr = strings.Join(colNames, " AND ") - statement.BeanArgs = args + statement.ConditionStr = strings.Join(colNames, " AND ") + statement.BeanArgs = args - var columnStr string = statement.ColumnStr - if columnStr == "" { - columnStr = statement.genColumnStr() - } + var columnStr string = statement.ColumnStr + if columnStr == "" { + columnStr = statement.genColumnStr() + } - return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...) + return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...) } func (s *Statement) genAddColumnStr(col *Column) (string, []interface{}) { - quote := s.Engine.Quote - sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()), - col.String(s.Engine.dialect)) - return sql, []interface{}{} + quote := s.Engine.Quote + sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()), + col.String(s.Engine.dialect)) + return sql, []interface{}{} } func (s *Statement) genAddIndexStr(idxName string, cols []string) (string, []interface{}) { - quote := s.Engine.Quote - colstr := quote(strings.Join(cols, quote(", "))) - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(idxName), quote(s.TableName()), colstr) - return sql, []interface{}{} + quote := s.Engine.Quote + colstr := quote(strings.Join(cols, quote(", "))) + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(idxName), quote(s.TableName()), colstr) + return sql, []interface{}{} } func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []interface{}) { - quote := s.Engine.Quote - colstr := quote(strings.Join(cols, quote(", "))) - sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uqeName), quote(s.TableName()), colstr) - return sql, []interface{}{} + quote := s.Engine.Quote + colstr := quote(strings.Join(cols, quote(", "))) + sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uqeName), quote(s.TableName()), colstr) + return sql, []interface{}{} } func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { - table := statement.Engine.autoMap(bean) - statement.RefTable = table + table := statement.Engine.autoMap(bean) + statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - statement.allUseBool, statement.boolColumnMap) + colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, + statement.allUseBool, statement.boolColumnMap) - statement.ConditionStr = strings.Join(colNames, " AND ") - statement.BeanArgs = args - var id string = "*" - if table.PrimaryKey != "" { - id = statement.Engine.Quote(table.PrimaryKey) - } - return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) + statement.ConditionStr = strings.Join(colNames, " AND ") + statement.BeanArgs = args + var id string = "*" + if table.PrimaryKey != "" { + id = statement.Engine.Quote(table.PrimaryKey) + } + return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) } func (statement *Statement) genSelectSql(columnStr string) (a string) { - if statement.GroupByStr != "" { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - statement.GroupByStr = columnStr - } - var distinct string - if statement.IsDistinct { - distinct = "DISTINCT " - } + if statement.GroupByStr != "" { + columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + statement.GroupByStr = columnStr + } + var distinct string + if statement.IsDistinct { + distinct = "DISTINCT " + } - // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr, - statement.Engine.Quote(statement.TableName())) - if statement.JoinStr != "" { - a = fmt.Sprintf("%v %v", a, statement.JoinStr) - } - statement.processIdParam() - if statement.WhereStr != "" { - a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) - if statement.ConditionStr != "" { - a = fmt.Sprintf("%v AND %v", a, statement.ConditionStr) - } - } else if statement.ConditionStr != "" { - a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr) - } + // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern + a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr, + statement.Engine.Quote(statement.TableName())) + if statement.JoinStr != "" { + a = fmt.Sprintf("%v %v", a, statement.JoinStr) + } + statement.processIdParam() + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + if statement.ConditionStr != "" { + a = fmt.Sprintf("%v AND %v", a, statement.ConditionStr) + } + } else if statement.ConditionStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr) + } - if statement.GroupByStr != "" { - a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) - } - if statement.HavingStr != "" { - a = fmt.Sprintf("%v %v", a, statement.HavingStr) - } - if statement.OrderStr != "" { - a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) - } - if statement.Start > 0 { - a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) - } - return + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + if statement.Start > 0 { + a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) + } else if statement.LimitN > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) + } + return } func (statement *Statement) processIdParam() { - if statement.IdParam != nil { - i := 0 - colCnt := len(statement.RefTable.ColumnsSeq) - for _, elem := range *(statement.IdParam) { - for ; i < colCnt; i++ { - colName := statement.RefTable.ColumnsSeq[i] - col := statement.RefTable.Columns[colName] - if col.IsPrimaryKey { - statement.And(fmt.Sprintf("%v=?", col.Name), elem) - i++ - break - } - } - } + if statement.IdParam != nil { + i := 0 + colCnt := len(statement.RefTable.ColumnsSeq) + for _, elem := range *(statement.IdParam) { + for ; i < colCnt; i++ { + colName := statement.RefTable.ColumnsSeq[i] + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey { + statement.And(fmt.Sprintf("%v=?", col.Name), elem) + i++ + break + } + } + } - // !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it - // as empty string for now, so this will result sql exec failed instead of unexpected - // false update/delete - for ; i < colCnt; i++ { - colName := statement.RefTable.ColumnsSeq[i] - col := statement.RefTable.Columns[colName] - if col.IsPrimaryKey { - statement.And(fmt.Sprintf("%v=?", col.Name), "") - } - } - } + // !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it + // as empty string for now, so this will result sql exec failed instead of unexpected + // false update/delete + for ; i < colCnt; i++ { + colName := statement.RefTable.ColumnsSeq[i] + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey { + statement.And(fmt.Sprintf("%v=?", col.Name), "") + } + } + } } - diff --git a/table.go b/table.go index e408ea85..976e30f9 100644 --- a/table.go +++ b/table.go @@ -1,335 +1,335 @@ package xorm import ( - "reflect" - "sort" - "strings" - "time" + "reflect" + "sort" + "strings" + "time" ) // xorm SQL types type SQLType struct { - Name string - DefaultLength int - DefaultLength2 int + Name string + DefaultLength int + DefaultLength2 int } func (s *SQLType) IsText() bool { - return s.Name == Char || s.Name == Varchar || s.Name == TinyText || - s.Name == Text || s.Name == MediumText || s.Name == LongText + return s.Name == Char || s.Name == Varchar || s.Name == TinyText || + s.Name == Text || s.Name == MediumText || s.Name == LongText } func (s *SQLType) IsBlob() bool { - return (s.Name == TinyBlob) || (s.Name == Blob) || - s.Name == MediumBlob || s.Name == LongBlob || - s.Name == Binary || s.Name == VarBinary || s.Name == Bytea + return (s.Name == TinyBlob) || (s.Name == Blob) || + s.Name == MediumBlob || s.Name == LongBlob || + s.Name == Binary || s.Name == VarBinary || s.Name == Bytea } const () var ( - Bit = "BIT" - TinyInt = "TINYINT" - SmallInt = "SMALLINT" - MediumInt = "MEDIUMINT" - Int = "INT" - Integer = "INTEGER" - BigInt = "BIGINT" + Bit = "BIT" + TinyInt = "TINYINT" + SmallInt = "SMALLINT" + MediumInt = "MEDIUMINT" + Int = "INT" + Integer = "INTEGER" + BigInt = "BIGINT" - Char = "CHAR" - Varchar = "VARCHAR" - TinyText = "TINYTEXT" - Text = "TEXT" - MediumText = "MEDIUMTEXT" - LongText = "LONGTEXT" - Binary = "BINARY" - VarBinary = "VARBINARY" + Char = "CHAR" + Varchar = "VARCHAR" + TinyText = "TINYTEXT" + Text = "TEXT" + MediumText = "MEDIUMTEXT" + LongText = "LONGTEXT" - Date = "DATE" - DateTime = "DATETIME" - Time = "TIME" - TimeStamp = "TIMESTAMP" - TimeStampz = "TIMESTAMPZ" + Date = "DATE" + DateTime = "DATETIME" + Time = "TIME" + TimeStamp = "TIMESTAMP" + TimeStampz = "TIMESTAMPZ" - Decimal = "DECIMAL" - Numeric = "NUMERIC" + Decimal = "DECIMAL" + Numeric = "NUMERIC" - Real = "REAL" - Float = "FLOAT" - Double = "DOUBLE" + Real = "REAL" + Float = "FLOAT" + Double = "DOUBLE" - TinyBlob = "TINYBLOB" - Blob = "BLOB" - MediumBlob = "MEDIUMBLOB" - LongBlob = "LONGBLOB" - Bytea = "BYTEA" + Binary = "BINARY" + VarBinary = "VARBINARY" + TinyBlob = "TINYBLOB" + Blob = "BLOB" + MediumBlob = "MEDIUMBLOB" + LongBlob = "LONGBLOB" + Bytea = "BYTEA" - Bool = "BOOL" + Bool = "BOOL" - Serial = "SERIAL" - BigSerial = "BIGSERIAL" + Serial = "SERIAL" + BigSerial = "BIGSERIAL" - sqlTypes = map[string]bool{ - Bit: true, - TinyInt: true, - SmallInt: true, - MediumInt: true, - Int: true, - Integer: true, - BigInt: true, + sqlTypes = map[string]bool{ + Bit: true, + TinyInt: true, + SmallInt: true, + MediumInt: true, + Int: true, + Integer: true, + BigInt: true, - Char: true, - Varchar: true, - TinyText: true, - Text: true, - MediumText: true, - LongText: true, - Binary: true, - VarBinary: true, + Char: true, + Varchar: true, + TinyText: true, + Text: true, + MediumText: true, + LongText: true, - Date: true, - DateTime: true, - Time: true, - TimeStamp: true, - TimeStampz: true, + Date: true, + DateTime: true, + Time: true, + TimeStamp: true, + TimeStampz: true, - Decimal: true, - Numeric: true, + Decimal: true, + Numeric: true, - Real: true, - Float: true, - Double: true, - TinyBlob: true, - Blob: true, - MediumBlob: true, - LongBlob: true, - Bytea: true, + Binary: true, + VarBinary: true, + Real: true, + Float: true, + Double: true, + TinyBlob: true, + Blob: true, + MediumBlob: true, + LongBlob: true, + Bytea: true, - Bool: true, + Bool: true, - Serial: true, - BigSerial: true, - } + Serial: true, + BigSerial: true, + } - intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} - uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} + intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} + uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} ) var b byte var tm time.Time func Type2SQLType(t reflect.Type) (st SQLType) { - switch k := t.Kind(); k { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - st = SQLType{Int, 0, 0} - case reflect.Int64, reflect.Uint64: - st = SQLType{BigInt, 0, 0} - case reflect.Float32: - st = SQLType{Float, 0, 0} - case reflect.Float64: - st = SQLType{Double, 0, 0} - case reflect.Complex64, reflect.Complex128: - st = SQLType{Varchar, 64, 0} - case reflect.Array, reflect.Slice, reflect.Map: - if t.Elem() == reflect.TypeOf(b) { - st = SQLType{Blob, 0, 0} - } else { - st = SQLType{Text, 0, 0} - } - case reflect.Bool: - st = SQLType{Bool, 0, 0} - case reflect.String: - st = SQLType{Varchar, 255, 0} - case reflect.Struct: - if t == reflect.TypeOf(tm) { - st = SQLType{DateTime, 0, 0} - } else { - st = SQLType{Text, 0, 0} - } - case reflect.Ptr: - st, _ = ptrType2SQLType(t) - default: - st = SQLType{Text, 0, 0} - } - return + switch k := t.Kind(); k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + st = SQLType{Int, 0, 0} + case reflect.Int64, reflect.Uint64: + st = SQLType{BigInt, 0, 0} + case reflect.Float32: + st = SQLType{Float, 0, 0} + case reflect.Float64: + st = SQLType{Double, 0, 0} + case reflect.Complex64, reflect.Complex128: + st = SQLType{Varchar, 64, 0} + case reflect.Array, reflect.Slice, reflect.Map: + if t.Elem() == reflect.TypeOf(b) { + st = SQLType{Blob, 0, 0} + } else { + st = SQLType{Text, 0, 0} + } + case reflect.Bool: + st = SQLType{Bool, 0, 0} + case reflect.String: + st = SQLType{Varchar, 255, 0} + case reflect.Struct: + if t == reflect.TypeOf(tm) { + st = SQLType{DateTime, 0, 0} + } else { + st = SQLType{Text, 0, 0} + } + case reflect.Ptr: + st, _ = ptrType2SQLType(t) + default: + st = SQLType{Text, 0, 0} + } + return } func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { - has = true + has = true - switch t { - case reflect.TypeOf(&c_EMPTY_STRING): - st = SQLType{Varchar, 255, 0} - return - case reflect.TypeOf(&c_BOOL_DEFAULT): - st = SQLType{Bool, 0, 0} - case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT): - st = SQLType{Varchar, 64, 0} - case reflect.TypeOf(&c_FLOAT32_DEFAULT): - st = SQLType{Float, 0, 0} - case reflect.TypeOf(&c_FLOAT64_DEFAULT): - st = SQLType{Double, 0, 0} - case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT): - st = SQLType{BigInt, 0, 0} - case reflect.TypeOf(&c_TIME_DEFAULT): - st = SQLType{DateTime, 0, 0} - case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT): - st = SQLType{Int, 0, 0} - default: - has = false - } - return + switch t { + case reflect.TypeOf(&c_EMPTY_STRING): + st = SQLType{Varchar, 255, 0} + return + case reflect.TypeOf(&c_BOOL_DEFAULT): + st = SQLType{Bool, 0, 0} + case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT): + st = SQLType{Varchar, 64, 0} + case reflect.TypeOf(&c_FLOAT32_DEFAULT): + st = SQLType{Float, 0, 0} + case reflect.TypeOf(&c_FLOAT64_DEFAULT): + st = SQLType{Double, 0, 0} + case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT): + st = SQLType{BigInt, 0, 0} + case reflect.TypeOf(&c_TIME_DEFAULT): + st = SQLType{DateTime, 0, 0} + case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT): + st = SQLType{Int, 0, 0} + default: + has = false + } + return } // default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { - name := strings.ToUpper(st.Name) - switch name { - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: - return reflect.TypeOf(1) - case BigInt, BigSerial: - return reflect.TypeOf(int64(1)) - case Float, Real: - return reflect.TypeOf(float32(1)) - case Double: - return reflect.TypeOf(float64(1)) - case Char, Varchar, TinyText, Text, MediumText, LongText: - return reflect.TypeOf("") - case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: - return reflect.TypeOf([]byte{}) - case Bool: - return reflect.TypeOf(true) - case DateTime, Date, Time, TimeStamp, TimeStampz: - return reflect.TypeOf(tm) - case Decimal, Numeric: - return reflect.TypeOf("") - default: - return reflect.TypeOf("") - } + name := strings.ToUpper(st.Name) + switch name { + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: + return reflect.TypeOf(1) + case BigInt, BigSerial: + return reflect.TypeOf(int64(1)) + case Float, Real: + return reflect.TypeOf(float32(1)) + case Double: + return reflect.TypeOf(float64(1)) + case Char, Varchar, TinyText, Text, MediumText, LongText: + return reflect.TypeOf("") + case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: + return reflect.TypeOf([]byte{}) + case Bool: + return reflect.TypeOf(true) + case DateTime, Date, Time, TimeStamp, TimeStampz: + return reflect.TypeOf(tm) + case Decimal, Numeric: + return reflect.TypeOf("") + default: + return reflect.TypeOf("") + } } const ( - IndexType = iota + 1 - UniqueType + IndexType = iota + 1 + UniqueType ) // database index type Index struct { - Name string - Type int - Cols []string + Name string + Type int + Cols []string } // add columns which will be composite index func (index *Index) AddColumn(cols ...string) { - for _, col := range cols { - index.Cols = append(index.Cols, col) - } + for _, col := range cols { + index.Cols = append(index.Cols, col) + } } // new an index func NewIndex(name string, indexType int) *Index { - return &Index{name, indexType, make([]string, 0)} + return &Index{name, indexType, make([]string, 0)} } const ( - TWOSIDES = iota + 1 - ONLYTODB - ONLYFROMDB + TWOSIDES = iota + 1 + ONLYTODB + ONLYFROMDB ) // database column type Column struct { - Name string - FieldName string - SQLType SQLType - Length int - Length2 int - Nullable bool - Default string - Indexes map[string]bool - IsPrimaryKey bool - IsAutoIncrement bool - MapType int - IsCreated bool - IsUpdated bool - IsCascade bool - IsVersion bool + Name string + FieldName string + SQLType SQLType + Length int + Length2 int + Nullable bool + Default string + Indexes map[string]bool + IsPrimaryKey bool + IsAutoIncrement bool + MapType int + IsCreated bool + IsUpdated bool + IsCascade bool + IsVersion bool } // generate column description string according dialect func (col *Column) String(d dialect) string { - sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " + sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " - sql += d.SqlType(col) + " " + sql += d.SqlType(col) + " " - if col.IsPrimaryKey { - sql += "PRIMARY KEY " - if col.IsAutoIncrement { - sql += d.AutoIncrStr() + " " - } - } + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + if col.IsAutoIncrement { + sql += d.AutoIncrStr() + " " + } + } - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } - if col.Default != "" { - sql += "DEFAULT " + col.Default + " " - } + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } - return sql + return sql } func (col *Column) stringNoPk(d dialect) string { - sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " + sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " - sql += d.SqlType(col) + " " + sql += d.SqlType(col) + " " - if col.Nullable { - sql += "NULL " - } else { - sql += "NOT NULL " - } + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } - if col.Default != "" { - sql += "DEFAULT " + col.Default + " " - } + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } - return sql + return sql } // return col's filed of struct's value func (col *Column) ValueOf(bean interface{}) reflect.Value { - var fieldValue reflect.Value - if strings.Contains(col.FieldName, ".") { - fields := strings.Split(col.FieldName, ".") - if len(fields) > 2 { - return reflect.ValueOf(nil) - } + var fieldValue reflect.Value + if strings.Contains(col.FieldName, ".") { + fields := strings.Split(col.FieldName, ".") + if len(fields) > 2 { + return reflect.ValueOf(nil) + } - fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0]) - fieldValue = fieldValue.FieldByName(fields[1]) - } else { - fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) - } - return fieldValue + fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0]) + fieldValue = fieldValue.FieldByName(fields[1]) + } else { + fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + } + return fieldValue } // database table type Table struct { - Name string - Type reflect.Type - ColumnsSeq []string - Columns map[string]*Column - Indexes map[string]*Index - PrimaryKey string - Created map[string]bool - Updated string - Version string - Cacher Cacher + Name string + Type reflect.Type + ColumnsSeq []string + Columns map[string]*Column + Indexes map[string]*Index + PrimaryKey string + Created map[string]bool + Updated string + Version string + Cacher Cacher } /* @@ -344,90 +344,90 @@ func NewTable(name string, t reflect.Type) *Table { // if has primary key, return column func (table *Table) PKColumn() *Column { - return table.Columns[table.PrimaryKey] + return table.Columns[table.PrimaryKey] } func (table *Table) VersionColumn() *Column { - return table.Columns[table.Version] + return table.Columns[table.Version] } // add a column to table func (table *Table) AddColumn(col *Column) { - table.ColumnsSeq = append(table.ColumnsSeq, col.Name) - table.Columns[col.Name] = col - if col.IsPrimaryKey { - table.PrimaryKey = col.Name - } - if col.IsCreated { - table.Created[col.Name] = true - } - if col.IsUpdated { - table.Updated = col.Name - } - if col.IsVersion { - table.Version = col.Name - } + table.ColumnsSeq = append(table.ColumnsSeq, col.Name) + table.Columns[col.Name] = col + if col.IsPrimaryKey { + table.PrimaryKey = col.Name + } + if col.IsCreated { + table.Created[col.Name] = true + } + if col.IsUpdated { + table.Updated = col.Name + } + if col.IsVersion { + table.Version = col.Name + } } // add an index or an unique to table func (table *Table) AddIndex(index *Index) { - table.Indexes[index.Name] = index + table.Indexes[index.Name] = index } func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { - colNames := make([]string, 0) - args := make([]interface{}, 0) + colNames := make([]string, 0) + args := make([]interface{}, 0) - for _, col := range table.Columns { - if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { - if _, ok := session.Statement.columnMap[col.Name]; !ok { - continue - } - } - if col.MapType == ONLYFROMDB { - continue - } + for _, col := range table.Columns { + if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { + if _, ok := session.Statement.columnMap[col.Name]; !ok { + continue + } + } + if col.MapType == ONLYFROMDB { + continue + } - fieldValue := col.ValueOf(bean) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue - } + fieldValue := col.ValueOf(bean) + if col.IsAutoIncrement && fieldValue.Int() == 0 { + continue + } - if session.Statement.ColumnStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; !ok { - continue - } - } - if session.Statement.OmitStr != "" { - if _, ok := session.Statement.columnMap[col.Name]; ok { - continue - } - } + if session.Statement.ColumnStr != "" { + if _, ok := session.Statement.columnMap[col.Name]; !ok { + continue + } + } + if session.Statement.OmitStr != "" { + if _, ok := session.Statement.columnMap[col.Name]; ok { + continue + } + } - if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { - args = append(args, time.Now()) - } else if col.IsVersion && session.Statement.checkVersion { - args = append(args, 1) - } else { - arg, err := session.value2Interface(col, fieldValue) - if err != nil { - return colNames, args, err - } - args = append(args, arg) - } + if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { + args = append(args, time.Now()) + } else if col.IsVersion && session.Statement.checkVersion { + args = append(args, 1) + } else { + arg, err := session.value2Interface(col, fieldValue) + if err != nil { + return colNames, args, err + } + args = append(args, arg) + } - if includeQuote { - colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") - } else { - colNames = append(colNames, col.Name) - } - } - return colNames, args, nil + if includeQuote { + colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?") + } else { + colNames = append(colNames, col.Name) + } + } + return colNames, args, nil } // Conversion is an interface. A type implements Conversion will according // the custom method to fill into database and retrieve from database. type Conversion interface { - FromDB([]byte) error - ToDB() ([]byte, error) + FromDB([]byte) error + ToDB() ([]byte, error) } diff --git a/xorm.go b/xorm.go index 060589d4..23a4a69b 100644 --- a/xorm.go +++ b/xorm.go @@ -1,58 +1,58 @@ package xorm import ( - "errors" - "fmt" - "os" - "reflect" - "runtime" - "sync" + "errors" + "fmt" + "os" + "reflect" + "runtime" + "sync" ) const ( - version string = "0.2.2" + version string = "0.2.3" ) func close(engine *Engine) { - engine.Close() + engine.Close() } // new a db manager according to the parameter. Currently support four // drivers func NewEngine(driverName string, dataSourceName string) (*Engine, error) { - engine := &Engine{DriverName: driverName, - DataSourceName: dataSourceName, Filters: make([]Filter, 0)} - engine.SetMapper(SnakeMapper{}) + engine := &Engine{DriverName: driverName, + DataSourceName: dataSourceName, Filters: make([]Filter, 0)} + engine.SetMapper(SnakeMapper{}) - if driverName == SQLITE { - engine.dialect = &sqlite3{} - } else if driverName == MYSQL { - engine.dialect = &mysql{} - } else if driverName == POSTGRES { - engine.dialect = &postgres{} - engine.Filters = append(engine.Filters, &PgSeqFilter{}) - engine.Filters = append(engine.Filters, &QuoteFilter{}) - } else if driverName == MYMYSQL { - engine.dialect = &mymysql{} - } else { - return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) - } - err := engine.dialect.Init(driverName, dataSourceName) - if err != nil { - return nil, err - } + if driverName == SQLITE { + engine.dialect = &sqlite3{} + } else if driverName == MYSQL { + engine.dialect = &mysql{} + } else if driverName == POSTGRES { + engine.dialect = &postgres{} + engine.Filters = append(engine.Filters, &PgSeqFilter{}) + engine.Filters = append(engine.Filters, &QuoteFilter{}) + } else if driverName == MYMYSQL { + engine.dialect = &mymysql{} + } else { + return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) + } + err := engine.dialect.Init(driverName, dataSourceName) + if err != nil { + return nil, err + } - engine.Tables = make(map[reflect.Type]*Table) - engine.mutex = &sync.Mutex{} - engine.TagIdentifier = "xorm" + engine.Tables = make(map[reflect.Type]*Table) + engine.mutex = &sync.Mutex{} + engine.TagIdentifier = "xorm" - engine.Filters = append(engine.Filters, &IdFilter{}) - engine.Logger = os.Stdout + engine.Filters = append(engine.Filters, &IdFilter{}) + engine.Logger = os.Stdout - //engine.Pool = NewSimpleConnectPool() - //engine.Pool = NewNoneConnectPool() - //engine.Cacher = NewLRUCacher() - err = engine.SetPool(NewSysConnectPool()) - runtime.SetFinalizer(engine, close) - return engine, err + //engine.Pool = NewSimpleConnectPool() + //engine.Pool = NewNoneConnectPool() + //engine.Cacher = NewLRUCacher() + err = engine.SetPool(NewSysConnectPool()) + runtime.SetFinalizer(engine, close) + return engine, err }