diff --git a/dialects/dameng.go b/dialects/dameng.go index 8bed7f13..23d1836a 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) { // ModifyColumnSQL returns a SQL to modify SQL func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) } @@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl } } - s, _ := ColumnString(db, col, false) + s, _ := ColumnString(db, col, false, false) if _, err := b.WriteString(s); err != nil { return "", false, err } diff --git a/dialects/dialect.go b/dialects/dialect.go index 70d599e6..b74a4636 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -135,7 +135,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { @@ -209,7 +209,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa // AddColumnSQL returns a SQL to add a column func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, true) + s, _ := ColumnString(db.dialect, col, true, false) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) } @@ -241,7 +241,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { // ModifyColumnSQL returns a SQL to modify SQL func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } @@ -254,9 +254,7 @@ func (db *Base) ForUpdateSQL(query string) string { func (db *Base) SetParams(params map[string]string) { } -var ( - dialects = map[string]func() Dialect{} -) +var dialects = map[string]func() Dialect{} // RegisterDialect register database dialect func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { @@ -307,7 +305,7 @@ func init() { } // ColumnString generate column description string according dialect -func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { +func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) { bd := strings.Builder{} if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { @@ -322,6 +320,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } + if supportCollation && col.Collate != "" { + if _, err := bd.WriteString(" COLLATE "); err != nil { + return "", err + } + if _, err := bd.WriteString(col.Collate); err != nil { + return "", err + } + } + if includePrimaryKey && col.IsPrimaryKey { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err diff --git a/dialects/mssql.go b/dialects/mssql.go index 1b6fe692..991a503d 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) { } func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) } @@ -646,7 +646,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { diff --git a/dialects/mysql.go b/dialects/mysql.go index b941a41b..0169fb51 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -380,7 +380,7 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { quoter := db.dialect.Quoter() - s, _ := ColumnString(db, col, true) + s, _ := ColumnString(db, col, true, true) var b strings.Builder b.WriteString("ALTER TABLE ") quoter.QuoteTo(&b, tableName) @@ -394,6 +394,12 @@ func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { return b.String() } +// ModifyColumnSQL returns a SQL to modify SQL +func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false, true) + return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) +} + func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{db.uri.DBName, tableName} alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " + @@ -646,7 +652,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) if len(col.Comment) > 0 { diff --git a/dialects/oracle.go b/dialects/oracle.go index b0c5c38f..a5f8a5b2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - s, _ := ColumnString(db, col, false) + s, _ := ColumnString(db, col, false, false) sql += s // } sql = strings.TrimSpace(sql) diff --git a/dialects/postgres.go b/dialects/postgres.go index 5efe54f4..942ab934 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl } func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, true) + s, _ := ColumnString(db.dialect, col, true, false) quoter := db.dialect.Quoter() addColumnSQL := "" diff --git a/integrations/session_schema_test.go b/integrations/schema_test.go similarity index 91% rename from integrations/session_schema_test.go rename to integrations/schema_test.go index 3212d027..171a6d00 100644 --- a/integrations/session_schema_test.go +++ b/integrations/schema_test.go @@ -325,14 +325,14 @@ func TestIsTableEmpty(t *testing.T) { type PictureEmpty struct { Id int64 - Url string `xorm:"unique"` //image's url + Url string `xorm:"unique"` // image's url Title string Description string Created time.Time `xorm:"created"` ILike int PageView int From_url string // nolint - Pre_url string `xorm:"unique"` //pre view image's url + Pre_url string `xorm:"unique"` // pre view image's url Uid int64 } @@ -458,7 +458,7 @@ func TestSync2_2(t *testing.T) { assert.NoError(t, PrepareEngine()) - var tableNames = make(map[string]bool) + tableNames := make(map[string]bool) for i := 0; i < 10; i++ { tableName := fmt.Sprintf("test_sync2_index_%d", i) tableNames[tableName] = true @@ -536,3 +536,58 @@ func TestModifyColum(t *testing.T) { _, err := testEngine.Exec(alterSQL) assert.NoError(t, err) } + +func TestCollate(t *testing.T) { + type TestCollateColumn struct { + Id int64 + UserId int64 `xorm:"unique(s)"` + Name string `xorm:"varchar(20) unique(s) collate utf8mb4_general_ci"` + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(TestCollateColumn)) + + _, err := testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "test", + }) + assert.NoError(t, err) + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "Test", + }) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Since SQLITE don't support modify column SQL, currrently just ignore + if testEngine.Dialect().URI().DBType == schemas.SQLITE { + return + } + + alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{ + Name: "name", + SQLType: schemas.SQLType{ + Name: "VARCHAR", + }, + Length: 20, + Nullable: true, + DefaultIsEmpty: true, + Collate: "utf8mb4_bin", + }) + _, err = testEngine.Exec(alterSQL) + assert.NoError(t, err) + + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "test1", + }) + assert.NoError(t, err) + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "Test1", + }) + assert.NoError(t, err) +} diff --git a/schemas/column.go b/schemas/column.go index 5a579e92..c653ffe0 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -45,6 +45,7 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + Collate string } // NewColumn creates a new column diff --git a/tags/tag.go b/tags/tag.go index 41d525e1..4350ecbc 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{ "COMMENT": CommentTagHandler, "EXTENDS": ExtendsTagHandler, "UNSIGNED": UnsignedTagHandler, + "COLLATE": CollateTagHandler, } func init() { @@ -282,6 +283,13 @@ func CommentTagHandler(ctx *Context) error { return nil } +func CollateTagHandler(ctx *Context) error { + if len(ctx.params) > 0 { + ctx.col.Collate = strings.Trim(ctx.params[0], "' ") + } + return nil +} + // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} diff --git a/tags/tag_test.go b/tags/tag_test.go index 3ceeefd1..c6aa2d84 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -11,68 +11,82 @@ import ( ) func TestSplitTag(t *testing.T) { - var cases = []struct { + cases := []struct { tag string tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ - { - name: "not", - }, - { - name: "null", - }, - { - name: "default", - }, - { - name: "'2000-01-01 00:00:00'", - }, - { - name: "TIMESTAMP", - }, - }, - }, - {"TEXT", []tag{ - { - name: "TEXT", - }, - }, - }, - {"default('2000-01-01 00:00:00')", []tag{ - { - name: "default", - params: []string{ - "'2000-01-01 00:00:00'", + { + "not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", }, }, }, - }, - {"json binary", []tag{ - { - name: "json", - }, - { - name: "binary", + { + "TEXT", []tag{ + { + name: "TEXT", + }, }, }, - }, - {"numeric(10, 2)", []tag{ - { - name: "numeric", - params: []string{"10", "2"}, + { + "default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, }, }, - }, - {"numeric(10, 2) notnull", []tag{ - { - name: "numeric", - params: []string{"10", "2"}, - }, - { - name: "notnull", + { + "json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, }, }, + { + "numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + { + "numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, + { + "collate utf8mb4_bin", []tag{ + { + name: "collate", + params: []string{"utf8mb4_bin"}, + }, + }, }, }