From db6c12b3d0560123e96bbcba6a891a6ee2f8c4c0 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 12 Dec 2019 22:36:21 +0800 Subject: [PATCH] fix some bugs --- dialect_db2.go | 65 ++++++++++++++++++++++++-------------- integrations/cache_test.go | 8 ++--- session_schema.go | 4 +++ 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/dialect_db2.go b/dialect_db2.go index 4712f559..68a44242 100644 --- a/dialect_db2.go +++ b/dialect_db2.go @@ -34,28 +34,12 @@ func (db *db2) SqlType(c *core.Column) string { case core.Bit: res = core.Boolean return res - case core.MediumInt, core.Int, core.Integer: - if c.IsAutoIncrement { - return core.Serial - } - return core.Integer - case core.BigInt: - if c.IsAutoIncrement { - return core.BigSerial - } - return core.BigInt - case core.Serial, core.BigSerial: - c.IsAutoIncrement = true - c.Nullable = false - res = t case core.Binary, core.VarBinary: return core.Bytea case core.DateTime: res = core.TimeStamp case core.TimeStampz: return "timestamp with time zone" - case core.Float: - res = core.Real case core.TinyText, core.MediumText, core.LongText: res = core.Text case core.NVarchar: @@ -64,12 +48,7 @@ func (db *db2) SqlType(c *core.Column) string { return core.Uuid case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: return core.Bytea - case core.Double: - return "DOUBLE PRECISION" default: - if c.IsAutoIncrement { - return core.Serial - } res = t } @@ -118,6 +97,37 @@ func (db *db2) IndexOnTable() bool { return false } +func (db *db2) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string { + var sql string + sql = "CREATE TABLE " + if tableName == "" { + tableName = table.Name + } + + sql += db.Quote(tableName) + " (" + + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + sql += col.StringNoPk(db) + if col.IsAutoIncrement { + sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )" + } + sql = strings.TrimSpace(sql) + sql += ", " + } + + if len(pkList) > 0 { + sql += "PRIMARY KEY ( " + sql += db.Quote(strings.Join(pkList, db.Quote(","))) + sql += " ), " + } + + sql = sql[:len(sql)-2] + ")" + return sql +} + func (db *db2) IndexCheckSql(tableName, idxName string) (string, []interface{}) { if len(db.Schema) == 0 { args := []interface{}{tableName, idxName} @@ -298,10 +308,10 @@ where t.type = 'T' AND c.tabname = ?` func (db *db2) GetTables() ([]*core.Table, error) { args := []interface{}{} - s := "SELECT NAME FROM SYSIBM.SYSTABLES WHERE type = 'T'" + s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'" if len(db.Schema) != 0 { args = append(args, db.Schema) - s = s + " AND creator = ?" + s = s + " AND TABSCHEMA = ?" } db.LogSQL(s, args) @@ -392,6 +402,7 @@ type db2Driver struct{} func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { var dbName string + var defaultSchema string kv := strings.Split(dataSourceName, ";") for _, c := range kv { @@ -400,6 +411,8 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) switch strings.ToLower(vv[0]) { case "database": dbName = vv[1] + case "uid": + defaultSchema = vv[1] } } } @@ -407,5 +420,9 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) if dbName == "" { return nil, errors.New("no db name provided") } - return &core.Uri{DbName: dbName, DbType: "db2"}, nil + return &core.Uri{ + DbName: dbName, + DbType: "db2", + Schema: defaultSchema, + }, nil } diff --git a/integrations/cache_test.go b/integrations/cache_test.go index 44e817b1..71a37c1f 100644 --- a/integrations/cache_test.go +++ b/integrations/cache_test.go @@ -62,7 +62,7 @@ func TestCacheFind(t *testing.T) { } boxes = make([]MailBox, 0, 2) - assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes)) + assert.NoError(t, testEngine.Alias("a").Where("`a`.`id` > -1").Asc("a.id").Find(&boxes)) assert.EqualValues(t, 2, len(boxes)) for i, box := range boxes { assert.Equal(t, inserts[i].Id, box.Id) @@ -77,7 +77,7 @@ func TestCacheFind(t *testing.T) { } boxes2 := make([]MailBox4, 0, 2) - assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2)) + assert.NoError(t, testEngine.Table("mail_box").Where("`mail_box`.`id` > -1").Asc("mail_box.id").Find(&boxes2)) assert.EqualValues(t, 2, len(boxes2)) for i, box := range boxes2 { assert.Equal(t, inserts[i].Id, box.Id) @@ -164,14 +164,14 @@ func TestCacheGet(t *testing.T) { assert.NoError(t, err) var box1 MailBox3 - has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) + has, err := testEngine.Where("`id` = ?", inserts[0].Id).Get(&box1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box1.Username) assert.EqualValues(t, "pass1", box1.Password) var box2 MailBox3 - has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) + has, err = testEngine.Where("`id` = ?", inserts[0].Id).Get(&box2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box2.Username) diff --git a/session_schema.go b/session_schema.go index 2e64350f..7055e910 100644 --- a/session_schema.go +++ b/session_schema.go @@ -235,6 +235,7 @@ func (session *Session) Sync2(beans ...interface{}) error { tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) if err != nil { + fmt.Println("------", tables, err) return err } @@ -244,6 +245,8 @@ func (session *Session) Sync2(beans ...interface{}) error { session.resetStatement() }() + fmt.Println("-----", tables, len(tables), len(beans)) + for _, bean := range beans { v := utils.ReflectValue(bean) table, err := engine.tagParser.ParseWithCache(v) @@ -260,6 +263,7 @@ func (session *Session) Sync2(beans ...interface{}) error { var oriTable *schemas.Table for _, tb := range tables { + fmt.Println("----", tb.Name, engine.tbNameWithSchema(tb.Name), "===", tbName, engine.tbNameWithSchema(tbName)) if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { oriTable = tb break