Fix tests

This commit is contained in:
Lunny Xiao 2021-08-12 17:43:26 +08:00
parent e502385b12
commit 78d3504360
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 89 additions and 66 deletions

View File

@ -38,6 +38,12 @@ func (db *db2) Version(context.Context, core.Queryer) (*schemas.Version, error)
return nil, fmt.Errorf("not implementation") return nil, fmt.Errorf("not implementation")
} }
func (db *db2) Features() *DialectFeatures {
return &DialectFeatures{
DefaultClause: "WITH DEFAULT",
}
}
func (db *db2) ColumnTypeKind(t string) int { func (db *db2) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) { switch strings.ToUpper(t) {
case "DATE", "DATETIME", "DATETIME2", "TIME": case "DATE", "DATETIME", "DATETIME2", "TIME":
@ -58,9 +64,9 @@ func (db *db2) SQLType(c *schemas.Column) string {
res = schemas.SmallInt res = schemas.SmallInt
return res return res
case schemas.UnsignedBigInt: case schemas.UnsignedBigInt:
res = schemas.BigInt return schemas.BigInt
case schemas.UnsignedInt: case schemas.UnsignedInt:
res = schemas.BigInt return schemas.BigInt
case schemas.Bit, schemas.Bool, schemas.Boolean: case schemas.Bit, schemas.Bool, schemas.Boolean:
res = schemas.Boolean res = schemas.Boolean
return res return res
@ -119,35 +125,42 @@ func (db *db2) IndexOnTable() bool {
} }
func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string
sql = "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
sql += db.Quoter().Quote(tableName) + " (" quoter := db.Quoter()
var b strings.Builder
b.WriteString("CREATE TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
pkList := table.PrimaryKeys for i, colName := range table.ColumnsSeq() {
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db, col, false) if !col.DefaultIsEmpty {
sql += s col.Nullable = false
if col.IsAutoIncrement { }
sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )" s, _ := ColumnString(db, col, false)
b.WriteString(s)
if col.IsAutoIncrement {
b.WriteString(" GENERATED BY DEFAULT AS IDENTITY (START WITH 1, INCREMENT BY 1)")
}
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
} }
sql = strings.TrimSpace(sql)
sql += ", "
} }
if len(pkList) > 0 { if len(table.PrimaryKeys) > 0 {
sql += "PRIMARY KEY ( " b.WriteString(", PRIMARY KEY (")
sql += db.Quoter().Join(pkList, ",") b.WriteString(quoter.Join(table.PrimaryKeys, ","))
sql += " ), " b.WriteString(")")
} }
sql = sql[:len(sql)-2] + ")" b.WriteString(")")
return []string{sql}, false
return []string{b.String()}, true
} }
func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -32,17 +32,21 @@ type URI struct {
// SetSchema set schema // SetSchema set schema
func (uri *URI) SetSchema(schema string) { func (uri *URI) SetSchema(schema string) {
// hack me
if uri.DBType == schemas.POSTGRES { if uri.DBType == schemas.POSTGRES {
uri.Schema = strings.TrimSpace(schema) uri.Schema = strings.TrimSpace(schema)
} }
} }
type DialectFeatures struct {
DefaultClause string // default key word
}
// Dialect represents a kind of database // Dialect represents a kind of database
type Dialect interface { type Dialect interface {
Init(*URI) error Init(*URI) error
URI() *URI URI() *URI
Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error)
Features() *DialectFeatures
SQLType(*schemas.Column) string SQLType(*schemas.Column) string
Alias(string) string // return what a sql type's alias of Alias(string) string // return what a sql type's alias of
@ -103,6 +107,12 @@ func (db *Base) URI() *URI {
return db.uri return db.uri
} }
func (db *Base) Features() *DialectFeatures {
return &DialectFeatures{
DefaultClause: "DEFAULT",
}
}
// DropTableSQL returns drop table SQL // DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) { func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote quote := db.dialect.Quoter().Quote
@ -253,43 +263,46 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err return "", err
} }
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if includePrimaryKey && col.IsPrimaryKey { if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString("PRIMARY KEY "); err != nil { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err return "", err
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { if err := bd.WriteByte(' '); err != nil {
return "", err return "", err
} }
if err := bd.WriteByte(' '); err != nil { if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
return "", err return "", err
} }
} }
} }
if col.Default != "" { if col.Default != "" {
if _, err := bd.WriteString("DEFAULT "); err != nil { if err := bd.WriteByte(' '); err != nil {
return "", err return "", err
} }
if _, err := bd.WriteString(col.Default); err != nil { if _, err := bd.WriteString(dialect.Features().DefaultClause); err != nil {
return "", err return "", err
} }
if err := bd.WriteByte(' '); err != nil { if err := bd.WriteByte(' '); err != nil {
return "", err return "", err
} }
if _, err := bd.WriteString(col.Default); err != nil {
return "", err
}
}
if err := bd.WriteByte(' '); err != nil {
return "", err
} }
if col.Nullable { if col.Nullable {
if _, err := bd.WriteString("NULL "); err != nil { if _, err := bd.WriteString("NULL"); err != nil {
return "", err return "", err
} }
} else { } else {
if _, err := bd.WriteString("NOT NULL "); err != nil { if _, err := bd.WriteString("NOT NULL"); err != nil {
return "", err return "", err
} }
} }

View File

@ -966,38 +966,35 @@ func (db *postgres) AutoIncrStr() string {
} }
func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
quoter := db.Quoter() quoter := db.Quoter()
sql += quoter.Quote(tableName) var b strings.Builder
sql += " (" b.WriteString("CREATE TABLE IF NOT EXIST ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
if len(table.ColumnsSeq()) > 0 { for i, colName := range table.ColumnsSeq() {
pkList := table.PrimaryKeys col := table.GetColumn(colName)
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1)
b.WriteString(s)
for _, colName := range table.ColumnsSeq() { if len(table.PrimaryKeys) > 1 {
col := table.GetColumn(colName) b.WriteString("PRIMARY KEY ( ")
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) b.WriteString(quoter.Join(table.PrimaryKeys, ","))
sql += s b.WriteString(" )")
sql = strings.TrimSpace(sql)
sql += ", "
} }
if len(pkList) > 1 { if i != len(table.ColumnsSeq())-1 {
sql += "PRIMARY KEY ( " b.WriteString(", ")
sql += quoter.Join(pkList, ",")
sql += " ), "
} }
sql = sql[:len(sql)-2]
} }
sql += ")"
return []string{sql}, true b.WriteString(")")
return []string{b.String()}, false
} }
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -126,7 +126,7 @@ func TestDump(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, sess.Commit()) assert.NoError(t, sess.Commit())
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL, schemas.DB2} {
name := fmt.Sprintf("dump_%v.sql", tp) name := fmt.Sprintf("dump_%v.sql", tp)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.DumpAllToFile(name, tp)) assert.NoError(t, testEngine.DumpAllToFile(name, tp))

View File

@ -37,49 +37,49 @@ func TestBuilder(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
var cond Condition var cond Condition
has, err := testEngine.Where(builder.Eq{"col_name": "col1"}).Get(&cond) has, err := testEngine.Where(builder.Eq{"`col_name`": "col1"}).Get(&cond)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist") assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Eq{"op": OpEqual})). And(builder.Eq{"`op`": OpEqual})).
NoAutoCondition(). NoAutoCondition().
Get(&cond) Get(&cond)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist") assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1", "op": OpEqual, "value": "1"}). has, err = testEngine.Where(builder.Eq{"`col_name`": "col1", "`op`": OpEqual, "`value`": "1"}).
NoAutoCondition(). NoAutoCondition().
Get(&cond) Get(&cond)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist") assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Neq{"op": OpEqual})). And(builder.Neq{"`op`": OpEqual})).
NoAutoCondition(). NoAutoCondition().
Get(&cond) Get(&cond)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, false, has, "records should not exist") assert.Equal(t, false, has, "records should not exist")
var conds []Condition var conds []Condition
err = testEngine.Where(builder.Eq{"col_name": "col1"}. err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Eq{"op": OpEqual})). And(builder.Eq{"`op`": OpEqual})).
Find(&conds) Find(&conds)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist") assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0) conds = make([]Condition, 0)
err = testEngine.Where(builder.Like{"col_name", "col"}).Find(&conds) err = testEngine.Where(builder.Like{"`col_name`", "col"}).Find(&conds)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist") assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0) conds = make([]Condition, 0)
err = testEngine.Where(builder.Expr("col_name = ?", "col1")).Find(&conds) err = testEngine.Where(builder.Expr("`col_name` = ?", "col1")).Find(&conds)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist") assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0) conds = make([]Condition, 0)
err = testEngine.Where(builder.In("col_name", "col1", "col2")).Find(&conds) err = testEngine.Where(builder.In("`col_name`", "col1", "col2")).Find(&conds)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist") assert.EqualValues(t, 1, len(conds), "records should exist")
@ -91,8 +91,8 @@ func TestBuilder(t *testing.T) {
// complex condtions // complex condtions
var where = builder.NewCond() var where = builder.NewCond()
if true { if true {
where = where.And(builder.Eq{"col_name": "col1"}) where = where.And(builder.Eq{"`col_name`": "col1"})
where = where.Or(builder.And(builder.In("col_name", "col1", "col2"), builder.Expr("col_name = ?", "col1"))) where = where.Or(builder.And(builder.In("`col_name`", "col1", "col2"), builder.Expr("`col_name` = ?", "col1")))
} }
conds = make([]Condition, 0) conds = make([]Condition, 0)