diff --git a/dialects/dialect.go b/dialects/dialect.go index b74a4636..d1c5f200 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -320,11 +320,11 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, suppo return "", err } - if supportCollation && col.Collate != "" { + if supportCollation && col.Collation != "" { if _, err := bd.WriteString(" COLLATE "); err != nil { return "", err } - if _, err := bd.WriteString(col.Collate); err != nil { + if _, err := bd.WriteString(col.Collation); err != nil { return "", err } } diff --git a/dialects/mssql.go b/dialects/mssql.go index 991a503d..4350ed89 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) { } func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false, false) + s, _ := ColumnString(db.dialect, col, false, true) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) } @@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, - ISNULL(p.is_primary_key, 0), a.is_identity as is_identity + ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id left join sys.syscomments c on a.default_object_id=c.id @@ -474,10 +474,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - var name, ctype, vdefault string + var name, ctype, vdefault, collation string var maxLen, precision, scale int64 var nullable, isPK, defaultIsNull, isIncrement bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation) if err != nil { return nil, nil, err } @@ -499,6 +499,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } else { col.Length = maxLen } + col.Collation = collation switch ct { case "DATETIMEOFFSET": col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} @@ -646,7 +647,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { diff --git a/dialects/mysql.go b/dialects/mysql.go index 0169fb51..16d1fc66 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -410,7 +410,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + - alreadyQuoted + " AS NEEDS_QUOTE " + + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" @@ -426,10 +426,10 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName col := new(schemas.Column) col.Indexes = make(map[string]int) - var columnName, nullableStr, colType, colKey, extra, comment string + var columnName, nullableStr, colType, colKey, extra, comment, collation string var alreadyQuoted, isUnsigned bool var colDefault, maxLength *string - err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) + err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation) if err != nil { return nil, nil, err } @@ -445,6 +445,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } else { col.DefaultIsEmpty = true } + col.Collation = collation fields := strings.Fields(colType) if len(fields) == 2 && fields[1] == "unsigned" { @@ -537,7 +538,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{db.uri.DBName} - s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" rows, err := queryer.QueryContext(ctx, s, args...) @@ -549,9 +550,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { table := schemas.NewEmptyTable() - var name, engine string + var name, engine, collation string var autoIncr, comment *string - err = rows.Scan(&name, &engine, &autoIncr, &comment) + err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation) if err != nil { return nil, err } @@ -561,6 +562,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Comment = *comment } table.StoreEngine = engine + table.Collation = collation tables = append(tables, table) } if rows.Err() != nil { diff --git a/integrations/schema_test.go b/integrations/schema_test.go index 171a6d00..7c11c570 100644 --- a/integrations/schema_test.go +++ b/integrations/schema_test.go @@ -537,15 +537,37 @@ func TestModifyColum(t *testing.T) { assert.NoError(t, err) } -func TestCollate(t *testing.T) { - type TestCollateColumn struct { - Id int64 - UserId int64 `xorm:"unique(s)"` - Name string `xorm:"varchar(20) unique(s) collate utf8mb4_general_ci"` - } +type TestCollateColumn struct { + Id int64 + UserId int64 `xorm:"unique(s)"` + Name string `xorm:"varchar(20) unique(s)"` + dbtype string `xorm:"-"` +} +func (t TestCollateColumn) TableCollations() []*schemas.Collation { + if t.dbtype == string(schemas.MYSQL) { + return []*schemas.Collation{ + { + Name: "utf8mb4_general_ci", + Column: "name", + }, + } + } else if t.dbtype == string(schemas.MSSQL) { + return []*schemas.Collation{ + { + Name: "SQL_Latin1_General_CP1_CI", + Column: "name", + }, + } + } + return nil +} + +func TestCollate(t *testing.T) { assert.NoError(t, PrepareEngine()) - assertSync(t, new(TestCollateColumn)) + assertSync(t, &TestCollateColumn{ + dbtype: string(testEngine.Dialect().URI().DBType), + }) _, err := testEngine.Insert(&TestCollateColumn{ UserId: 1, @@ -556,14 +578,23 @@ func TestCollate(t *testing.T) { UserId: 1, Name: "Test", }) - if testEngine.Dialect().URI().DBType == schemas.MYSQL { + if testEngine.Dialect().URI().DBType == schemas.MYSQL || testEngine.Dialect().URI().DBType == schemas.MSSQL { assert.Error(t, err) } else { assert.NoError(t, err) } // Since SQLITE don't support modify column SQL, currrently just ignore - if testEngine.Dialect().URI().DBType == schemas.SQLITE { + if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL { + return + } + + var newCollation string + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + newCollation = "utf8mb4_bin" + } else if testEngine.Dialect().URI().DBType != schemas.MSSQL { + newCollation = "SQL_Latin1_General_CP1_CS" + } else { return } @@ -575,7 +606,7 @@ func TestCollate(t *testing.T) { Length: 20, Nullable: true, DefaultIsEmpty: true, - Collate: "utf8mb4_bin", + Collation: newCollation, }) _, err = testEngine.Exec(alterSQL) assert.NoError(t, err) diff --git a/schemas/collation.go b/schemas/collation.go new file mode 100644 index 00000000..acec5268 --- /dev/null +++ b/schemas/collation.go @@ -0,0 +1,10 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +type Collation struct { + Name string + Column string // blank means it's a table collation +} diff --git a/schemas/column.go b/schemas/column.go index c653ffe0..08d34b91 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -45,7 +45,7 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string - Collate string + Collation string } // NewColumn creates a new column diff --git a/schemas/table.go b/schemas/table.go index 91b33e06..5c38cc70 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -27,6 +27,7 @@ type Table struct { StoreEngine string Charset string Comment string + Collation string } // NewEmptyTable creates an empty table @@ -36,7 +37,8 @@ func NewEmptyTable() *Table { // NewTable creates a new Table object func NewTable(name string, t reflect.Type) *Table { - return &Table{Name: name, Type: t, + return &Table{ + Name: name, Type: t, columnsSeq: make([]string, 0), columns: make([]*Column, 0), columnsMap: make(map[string][]*Column), diff --git a/tags/parser.go b/tags/parser.go index 028f8d0b..a508ff83 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -31,6 +31,12 @@ type TableIndices interface { var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() +type TableCollations interface { + TableCollations() []*schemas.Collation +} + +var tpTableCollations = reflect.TypeOf((*TableIndices)(nil)).Elem() + // Parser represents a parser for xorm tag type Parser struct { identifier string @@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { } } + collations := tableCollations(v) + for _, collation := range collations { + if collation.Name == "" { + continue + } + if collation.Column == "" { + table.Collation = collation.Name + } else { + col := table.GetColumn(collation.Column) + if col == nil { + return nil, ErrUnsupportedType + } + col.Collation = collation.Name // this may override definition in struct tag + } + } + return table, nil } @@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index { } return nil } + +func tableCollations(v reflect.Value) []*schemas.Collation { + if v.Type().Implements(tpTableCollations) { + return v.Interface().(TableCollations).TableCollations() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableCollations) { + return v.Interface().(TableCollations).TableCollations() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableCollations) { + return v1.Interface().(TableCollations).TableCollations() + } + } + return nil +} diff --git a/tags/tag.go b/tags/tag.go index f369c99d..024c9c18 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -285,9 +285,9 @@ func CommentTagHandler(ctx *Context) error { func CollateTagHandler(ctx *Context) error { if len(ctx.params) > 0 { - ctx.col.Collate = ctx.params[0] + ctx.col.Collation = ctx.params[0] } else { - ctx.col.Collate = ctx.nextTag + ctx.col.Collation = ctx.nextTag ctx.ignoreNext = true } return nil