From 78d3504360ba174fd6255bd5c80f9c5637145632 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 12 Aug 2021 17:43:26 +0800 Subject: [PATCH] Fix tests --- dialects/db2.go | 53 +++++++++++++++++++------------ dialects/dialect.go | 37 ++++++++++++++------- dialects/postgres.go | 37 ++++++++++----------- integrations/engine_test.go | 2 +- integrations/session_cond_test.go | 26 +++++++-------- 5 files changed, 89 insertions(+), 66 deletions(-) diff --git a/dialects/db2.go b/dialects/db2.go index 74f2d750..2ea65019 100644 --- a/dialects/db2.go +++ b/dialects/db2.go @@ -38,6 +38,12 @@ func (db *db2) Version(context.Context, core.Queryer) (*schemas.Version, error) return nil, fmt.Errorf("not implementation") } +func (db *db2) Features() *DialectFeatures { + return &DialectFeatures{ + DefaultClause: "WITH DEFAULT", + } +} + func (db *db2) ColumnTypeKind(t string) int { switch strings.ToUpper(t) { case "DATE", "DATETIME", "DATETIME2", "TIME": @@ -58,9 +64,9 @@ func (db *db2) SQLType(c *schemas.Column) string { res = schemas.SmallInt return res case schemas.UnsignedBigInt: - res = schemas.BigInt + return schemas.BigInt case schemas.UnsignedInt: - res = schemas.BigInt + return schemas.BigInt case schemas.Bit, schemas.Bool, schemas.Boolean: res = schemas.Boolean return res @@ -119,35 +125,42 @@ func (db *db2) IndexOnTable() bool { } func (db *db2) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string - sql = "CREATE TABLE " if tableName == "" { tableName = table.Name } - sql += db.Quoter().Quote(tableName) + " (" + quoter := db.Quoter() + var b strings.Builder + b.WriteString("CREATE TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") - pkList := table.PrimaryKeys - - for _, colName := range table.ColumnsSeq() { + for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db, col, false) - sql += s - if col.IsAutoIncrement { - sql += " GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1 )" + if !col.DefaultIsEmpty { + col.Nullable = false + } + s, _ := ColumnString(db, col, false) + b.WriteString(s) + + if col.IsAutoIncrement { + b.WriteString(" GENERATED BY DEFAULT AS IDENTITY (START WITH 1, INCREMENT BY 1)") + } + + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") } - sql = strings.TrimSpace(sql) - sql += ", " } - if len(pkList) > 0 { - sql += "PRIMARY KEY ( " - sql += db.Quoter().Join(pkList, ",") - sql += " ), " + if len(table.PrimaryKeys) > 0 { + b.WriteString(", PRIMARY KEY (") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(")") } - sql = sql[:len(sql)-2] + ")" - return []string{sql}, false + b.WriteString(")") + + return []string{b.String()}, true } func (db *db2) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { diff --git a/dialects/dialect.go b/dialects/dialect.go index 5d65b1f6..12220523 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -32,17 +32,21 @@ type URI struct { // SetSchema set schema func (uri *URI) SetSchema(schema string) { - // hack me if uri.DBType == schemas.POSTGRES { uri.Schema = strings.TrimSpace(schema) } } +type DialectFeatures struct { + DefaultClause string // default key word +} + // Dialect represents a kind of database type Dialect interface { Init(*URI) error URI() *URI Version(ctx context.Context, queryer core.Queryer) (*schemas.Version, error) + Features() *DialectFeatures SQLType(*schemas.Column) string Alias(string) string // return what a sql type's alias of @@ -103,6 +107,12 @@ func (db *Base) URI() *URI { return db.uri } +func (db *Base) Features() *DialectFeatures { + return &DialectFeatures{ + DefaultClause: "DEFAULT", + } +} + // DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote @@ -253,43 +263,46 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } - if err := bd.WriteByte(' '); err != nil { - return "", err - } - if includePrimaryKey && col.IsPrimaryKey { - if _, err := bd.WriteString("PRIMARY KEY "); err != nil { + if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err } if col.IsAutoIncrement { - if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { + if err := bd.WriteByte(' '); err != nil { return "", err } - if err := bd.WriteByte(' '); err != nil { + if _, err := bd.WriteString(dialect.AutoIncrStr()); err != nil { return "", err } } } if col.Default != "" { - if _, err := bd.WriteString("DEFAULT "); err != nil { + if err := bd.WriteByte(' '); err != nil { return "", err } - if _, err := bd.WriteString(col.Default); err != nil { + if _, err := bd.WriteString(dialect.Features().DefaultClause); err != nil { return "", err } if err := bd.WriteByte(' '); err != nil { return "", err } + if _, err := bd.WriteString(col.Default); err != nil { + return "", err + } + } + + if err := bd.WriteByte(' '); err != nil { + return "", err } if col.Nullable { - if _, err := bd.WriteString("NULL "); err != nil { + if _, err := bd.WriteString("NULL"); err != nil { return "", err } } else { - if _, err := bd.WriteString("NOT NULL "); err != nil { + if _, err := bd.WriteString("NOT NULL"); err != nil { return "", err } } diff --git a/dialects/postgres.go b/dialects/postgres.go index 96ebfc85..d40f4aad 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -966,38 +966,35 @@ func (db *postgres) AutoIncrStr() string { } func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { - var sql string - sql = "CREATE TABLE IF NOT EXISTS " if tableName == "" { tableName = table.Name } quoter := db.Quoter() - sql += quoter.Quote(tableName) - sql += " (" + var b strings.Builder + b.WriteString("CREATE TABLE IF NOT EXIST ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" (") - if len(table.ColumnsSeq()) > 0 { - pkList := table.PrimaryKeys + for i, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + s, _ := ColumnString(db, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + b.WriteString(s) - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - s, _ := ColumnString(db, col, col.IsPrimaryKey && len(pkList) == 1) - sql += s - sql = strings.TrimSpace(sql) - sql += ", " + if len(table.PrimaryKeys) > 1 { + b.WriteString("PRIMARY KEY ( ") + b.WriteString(quoter.Join(table.PrimaryKeys, ",")) + b.WriteString(" )") } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += quoter.Join(pkList, ",") - sql += " ), " + if i != len(table.ColumnsSeq())-1 { + b.WriteString(", ") } - - sql = sql[:len(sql)-2] } - sql += ")" - return []string{sql}, true + b.WriteString(")") + + return []string{b.String()}, false } func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 02b35a2c..cfcdd985 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -126,7 +126,7 @@ func TestDump(t *testing.T) { assert.NoError(t, err) assert.NoError(t, sess.Commit()) - for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL} { + for _, tp := range []schemas.DBType{schemas.SQLITE, schemas.MYSQL, schemas.POSTGRES, schemas.MSSQL, schemas.DB2} { name := fmt.Sprintf("dump_%v.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.DumpAllToFile(name, tp)) diff --git a/integrations/session_cond_test.go b/integrations/session_cond_test.go index a0a91cad..05972ecc 100644 --- a/integrations/session_cond_test.go +++ b/integrations/session_cond_test.go @@ -37,49 +37,49 @@ func TestBuilder(t *testing.T) { assert.NoError(t, err) var cond Condition - has, err := testEngine.Where(builder.Eq{"col_name": "col1"}).Get(&cond) + has, err := testEngine.Where(builder.Eq{"`col_name`": "col1"}).Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Eq{"op": OpEqual})). + has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}. + And(builder.Eq{"`op`": OpEqual})). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1", "op": OpEqual, "value": "1"}). + has, err = testEngine.Where(builder.Eq{"`col_name`": "col1", "`op`": OpEqual, "`value`": "1"}). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, true, has, "records should exist") - has, err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Neq{"op": OpEqual})). + has, err = testEngine.Where(builder.Eq{"`col_name`": "col1"}. + And(builder.Neq{"`op`": OpEqual})). NoAutoCondition(). Get(&cond) assert.NoError(t, err) assert.Equal(t, false, has, "records should not exist") var conds []Condition - err = testEngine.Where(builder.Eq{"col_name": "col1"}. - And(builder.Eq{"op": OpEqual})). + err = testEngine.Where(builder.Eq{"`col_name`": "col1"}. + And(builder.Eq{"`op`": OpEqual})). Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.Like{"col_name", "col"}).Find(&conds) + err = testEngine.Where(builder.Like{"`col_name`", "col"}).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.Expr("col_name = ?", "col1")).Find(&conds) + err = testEngine.Where(builder.Expr("`col_name` = ?", "col1")).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") conds = make([]Condition, 0) - err = testEngine.Where(builder.In("col_name", "col1", "col2")).Find(&conds) + err = testEngine.Where(builder.In("`col_name`", "col1", "col2")).Find(&conds) assert.NoError(t, err) assert.EqualValues(t, 1, len(conds), "records should exist") @@ -91,8 +91,8 @@ func TestBuilder(t *testing.T) { // complex condtions var where = builder.NewCond() if true { - where = where.And(builder.Eq{"col_name": "col1"}) - where = where.Or(builder.And(builder.In("col_name", "col1", "col2"), builder.Expr("col_name = ?", "col1"))) + where = where.And(builder.Eq{"`col_name`": "col1"}) + where = where.Or(builder.And(builder.In("`col_name`", "col1", "col2"), builder.Expr("`col_name` = ?", "col1"))) } conds = make([]Condition, 0)