fix some bugs

This commit is contained in:
Lunny Xiao 2019-12-12 22:36:21 +08:00
parent 817dbe4f61
commit db6c12b3d0
3 changed files with 49 additions and 28 deletions

View File

@ -34,28 +34,12 @@ func (db *db2) SqlType(c *core.Column) string {
case core.Bit: case core.Bit:
res = core.Boolean res = core.Boolean
return res return res
case core.MediumInt, core.Int, core.Integer:
if c.IsAutoIncrement {
return core.Serial
}
return core.Integer
case core.BigInt:
if c.IsAutoIncrement {
return core.BigSerial
}
return core.BigInt
case core.Serial, core.BigSerial:
c.IsAutoIncrement = true
c.Nullable = false
res = t
case core.Binary, core.VarBinary: case core.Binary, core.VarBinary:
return core.Bytea return core.Bytea
case core.DateTime: case core.DateTime:
res = core.TimeStamp res = core.TimeStamp
case core.TimeStampz: case core.TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case core.Float:
res = core.Real
case core.TinyText, core.MediumText, core.LongText: case core.TinyText, core.MediumText, core.LongText:
res = core.Text res = core.Text
case core.NVarchar: case core.NVarchar:
@ -64,12 +48,7 @@ func (db *db2) SqlType(c *core.Column) string {
return core.Uuid return core.Uuid
case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob:
return core.Bytea return core.Bytea
case core.Double:
return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement {
return core.Serial
}
res = t res = t
} }
@ -118,6 +97,37 @@ func (db *db2) IndexOnTable() bool {
return false return false
} }
func (db *db2) CreateTableSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
sql += db.Quote(tableName) + " ("
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
sql += col.StringNoPk(db)
if col.IsAutoIncrement {
sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )"
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += db.Quote(strings.Join(pkList, db.Quote(",")))
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
return sql
}
func (db *db2) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *db2) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
if len(db.Schema) == 0 { if len(db.Schema) == 0 {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
@ -298,10 +308,10 @@ where t.type = 'T' AND c.tabname = ?`
func (db *db2) GetTables() ([]*core.Table, error) { func (db *db2) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT NAME FROM SYSIBM.SYSTABLES WHERE type = 'T'" s := "SELECT TABNAME FROM SYSCAT.TABLES WHERE type = 'T' AND OWNERTYPE = 'U'"
if len(db.Schema) != 0 { if len(db.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.Schema)
s = s + " AND creator = ?" s = s + " AND TABSCHEMA = ?"
} }
db.LogSQL(s, args) db.LogSQL(s, args)
@ -392,6 +402,7 @@ type db2Driver struct{}
func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) { func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
var dbName string var dbName string
var defaultSchema string
kv := strings.Split(dataSourceName, ";") kv := strings.Split(dataSourceName, ";")
for _, c := range kv { for _, c := range kv {
@ -400,6 +411,8 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error)
switch strings.ToLower(vv[0]) { switch strings.ToLower(vv[0]) {
case "database": case "database":
dbName = vv[1] dbName = vv[1]
case "uid":
defaultSchema = vv[1]
} }
} }
} }
@ -407,5 +420,9 @@ func (p *db2Driver) Parse(driverName, dataSourceName string) (*core.Uri, error)
if dbName == "" { if dbName == "" {
return nil, errors.New("no db name provided") return nil, errors.New("no db name provided")
} }
return &core.Uri{DbName: dbName, DbType: "db2"}, nil return &core.Uri{
DbName: dbName,
DbType: "db2",
Schema: defaultSchema,
}, nil
} }

View File

@ -62,7 +62,7 @@ func TestCacheFind(t *testing.T) {
} }
boxes = make([]MailBox, 0, 2) boxes = make([]MailBox, 0, 2)
assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes)) assert.NoError(t, testEngine.Alias("a").Where("`a`.`id` > -1").Asc("a.id").Find(&boxes))
assert.EqualValues(t, 2, len(boxes)) assert.EqualValues(t, 2, len(boxes))
for i, box := range boxes { for i, box := range boxes {
assert.Equal(t, inserts[i].Id, box.Id) assert.Equal(t, inserts[i].Id, box.Id)
@ -77,7 +77,7 @@ func TestCacheFind(t *testing.T) {
} }
boxes2 := make([]MailBox4, 0, 2) boxes2 := make([]MailBox4, 0, 2)
assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2)) assert.NoError(t, testEngine.Table("mail_box").Where("`mail_box`.`id` > -1").Asc("mail_box.id").Find(&boxes2))
assert.EqualValues(t, 2, len(boxes2)) assert.EqualValues(t, 2, len(boxes2))
for i, box := range boxes2 { for i, box := range boxes2 {
assert.Equal(t, inserts[i].Id, box.Id) assert.Equal(t, inserts[i].Id, box.Id)
@ -164,14 +164,14 @@ func TestCacheGet(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
var box1 MailBox3 var box1 MailBox3
has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) has, err := testEngine.Where("`id` = ?", inserts[0].Id).Get(&box1)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, "user1", box1.Username) assert.EqualValues(t, "user1", box1.Username)
assert.EqualValues(t, "pass1", box1.Password) assert.EqualValues(t, "pass1", box1.Password)
var box2 MailBox3 var box2 MailBox3
has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) has, err = testEngine.Where("`id` = ?", inserts[0].Id).Get(&box2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.EqualValues(t, "user1", box2.Username) assert.EqualValues(t, "user1", box2.Username)

View File

@ -235,6 +235,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil { if err != nil {
fmt.Println("------", tables, err)
return err return err
} }
@ -244,6 +245,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
session.resetStatement() session.resetStatement()
}() }()
fmt.Println("-----", tables, len(tables), len(beans))
for _, bean := range beans { for _, bean := range beans {
v := utils.ReflectValue(bean) v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v) table, err := engine.tagParser.ParseWithCache(v)
@ -260,6 +263,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
var oriTable *schemas.Table var oriTable *schemas.Table
for _, tb := range tables { for _, tb := range tables {
fmt.Println("----", tb.Name, engine.tbNameWithSchema(tb.Name), "===", tbName, engine.tbNameWithSchema(tbName))
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb oriTable = tb
break break