Fix tests

This commit is contained in:
Lunny Xiao 2021-08-12 17:43:26 +08:00
parent e502385b12
commit 78d3504360
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
5 changed files with 89 additions and 66 deletions

View File

@ -38,6 +38,12 @@ func (db *db2) Version(context.Context, core.Queryer) (*schemas.Version, error)
return nil, fmt.Errorf("not implementation")
}
func (db *db2) Features() *DialectFeatures {
return &DialectFeatures{
DefaultClause: "WITH DEFAULT",
}
}
func (db *db2) ColumnTypeKind(t string) int {
switch strings.ToUpper(t) {
case "DATE", "DATETIME", "DATETIME2", "TIME":
@ -58,9 +64,9 @@ func (db *db2) SQLType(c *schemas.Column) string {
res = schemas.SmallInt
return res
case schemas.UnsignedBigInt:
res = schemas.BigInt
return schemas.BigInt
case schemas.UnsignedInt:
res = schemas.BigInt
return schemas.BigInt
case schemas.Bit, schemas.Bool, schemas.Boolean:
res = schemas.Boolean
return res
@ -119,35 +125,42 @@ func (db *db2) IndexOnTable() bool {
}
func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string
sql = "CREATE TABLE "
if tableName == "" {
tableName = table.Name
}
sql += db.Quoter().Quote(tableName) + " ("
quoter := db.Quoter()
var b strings.Builder
b.WriteString("CREATE TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
s, _ := ColumnString(db, col, false)
sql += s
if col.IsAutoIncrement {
sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )"
if !col.DefaultIsEmpty {
col.Nullable = false
}
s, _ := ColumnString(db, col, false)
b.WriteString(s)
if col.IsAutoIncrement {
b.WriteString(" GENERATED BY DEFAULT AS IDENTITY (START WITH 1, INCREMENT BY 1)")
}
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 0 {
sql += "PRIMARY KEY ( "
sql += db.Quoter().Join(pkList, ",")
sql += " ), "
if len(table.PrimaryKeys) > 0 {
b.WriteString(", PRIMARY KEY (")
b.WriteString(quoter.Join(table.PrimaryKeys, ","))
b.WriteString(")")
}
sql = sql[:len(sql)-2] + ")"
return []string{sql}, false
b.WriteString(")")
return []string{b.String()}, true
}
func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -32,17 +32,21 @@ type URI struct {
// SetSchema set schema
func (uri *URI) SetSchema(schema string) {
// hack me
if uri.DBType == schemas.POSTGRES {
uri.Schema = strings.TrimSpace(schema)
}
}
type DialectFeatures struct {
DefaultClause string // default key word
}
// Dialect represents a kind of database
type Dialect interface {
Init(*URI) error
URI() *URI
Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error)
Features() *DialectFeatures
SQLType(*schemas.Column) string
Alias(string) string // return what a sql type's alias of
@ -103,6 +107,12 @@ func (db *Base) URI() *URI {
return db.uri
}
func (db *Base) Features() *DialectFeatures {
return &DialectFeatures{
DefaultClause: "DEFAULT",
}
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote
@ -253,43 +263,46 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString("PRIMARY KEY "); err != nil {
if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err
}
if col.IsAutoIncrement {
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if err := bd.WriteByte(' '); err != nil {
if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil {
return "", err
}
}
}
if col.Default != "" {
if _, err := bd.WriteString("DEFAULT "); err != nil {
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Default); err != nil {
if _, err := bd.WriteString(dialect.Features().DefaultClause); err != nil {
return "", err
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Default); err != nil {
return "", err
}
}
if err := bd.WriteByte(' '); err != nil {
return "", err
}
if col.Nullable {
if _, err := bd.WriteString("NULL "); err != nil {
if _, err := bd.WriteString("NULL"); err != nil {
return "", err
}
} else {
if _, err := bd.WriteString("NOT NULL "); err != nil {
if _, err := bd.WriteString("NOT NULL"); err != nil {
return "", err
}
}

View File

@ -966,38 +966,35 @@ func (db *postgres) AutoIncrStr() string {
}
func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
quoter := db.Quoter()
sql += quoter.Quote(tableName)
sql += " ("
var b strings.Builder
b.WriteString("CREATE TABLE IF NOT EXIST ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" (")
if len(table.ColumnsSeq()) > 0 {
pkList := table.PrimaryKeys
for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1)
b.WriteString(s)
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1)
sql += s
sql = strings.TrimSpace(sql)
sql += ", "
if len(table.PrimaryKeys) > 1 {
b.WriteString("PRIMARY KEY ( ")
b.WriteString(quoter.Join(table.PrimaryKeys, ","))
b.WriteString(" )")
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += quoter.Join(pkList, ",")
sql += " ), "
if i != len(table.ColumnsSeq())-1 {
b.WriteString(", ")
}
sql = sql[:len(sql)-2]
}
sql += ")"
return []string{sql}, true
b.WriteString(")")
return []string{b.String()}, false
}
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -126,7 +126,7 @@ func TestDump(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, sess.Commit())
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} {
for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL, schemas.DB2} {
name := fmt.Sprintf("dump_%v.sql", tp)
t.Run(name, func(t *testing.T) {
assert.NoError(t, testEngine.DumpAllToFile(name, tp))

View File

@ -37,49 +37,49 @@ func TestBuilder(t *testing.T) {
assert.NoError(t, err)
var cond Condition
has, err := testEngine.Where(builder.Eq{"col_name": "col1"}).Get(&cond)
has, err := testEngine.Where(builder.Eq{"`col_name`": "col1"}).Get(&cond)
assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1"}.
And(builder.Eq{"op": OpEqual})).
has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Eq{"`op`": OpEqual})).
NoAutoCondition().
Get(&cond)
assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1", "op": OpEqual, "value": "1"}).
has, err = testEngine.Where(builder.Eq{"`col_name`": "col1", "`op`": OpEqual, "`value`": "1"}).
NoAutoCondition().
Get(&cond)
assert.NoError(t, err)
assert.Equal(t, true, has, "records should exist")
has, err = testEngine.Where(builder.Eq{"col_name": "col1"}.
And(builder.Neq{"op": OpEqual})).
has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Neq{"`op`": OpEqual})).
NoAutoCondition().
Get(&cond)
assert.NoError(t, err)
assert.Equal(t, false, has, "records should not exist")
var conds []Condition
err = testEngine.Where(builder.Eq{"col_name": "col1"}.
And(builder.Eq{"op": OpEqual})).
err = testEngine.Where(builder.Eq{"`col_name`": "col1"}.
And(builder.Eq{"`op`": OpEqual})).
Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.Where(builder.Like{"col_name", "col"}).Find(&conds)
err = testEngine.Where(builder.Like{"`col_name`", "col"}).Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.Where(builder.Expr("col_name = ?", "col1")).Find(&conds)
err = testEngine.Where(builder.Expr("`col_name` = ?", "col1")).Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
conds = make([]Condition, 0)
err = testEngine.Where(builder.In("col_name", "col1", "col2")).Find(&conds)
err = testEngine.Where(builder.In("`col_name`", "col1", "col2")).Find(&conds)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(conds), "records should exist")
@ -91,8 +91,8 @@ func TestBuilder(t *testing.T) {
// complex condtions
var where = builder.NewCond()
if true {
where = where.And(builder.Eq{"col_name": "col1"})
where = where.Or(builder.And(builder.In("col_name", "col1", "col2"), builder.Expr("col_name = ?", "col1")))
where = where.And(builder.Eq{"`col_name`": "col1"})
where = where.Or(builder.And(builder.In("`col_name`", "col1", "col2"), builder.Expr("`col_name` = ?", "col1")))
}
conds = make([]Condition, 0)