Mysql support a new tag Collate (#2283)

Fix #237
Fix #2179

Reviewed-on: https://gitea.com/xorm/xorm/pulls/2283
This commit is contained in:
Lunny Xiao 2023-07-01 03:40:09 +00:00
parent 18f8e7a86c
commit d29fe49933
13 changed files with 289 additions and 78 deletions

View File

@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s)
} }
@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
} }
} }
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
if _, err := b.WriteString(s); err != nil { if _, err := b.WriteString(s); err != nil {
return "", false, err return "", false, err
} }

View File

@ -135,7 +135,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {
@ -209,7 +209,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
// AddColumnSQL returns a SQL to add a column // AddColumnSQL returns a SQL to add a column
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s)
} }
@ -241,7 +241,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
} }
@ -254,9 +254,7 @@ func (db *Base) ForUpdateSQL(query string) string {
func (db *Base) SetParams(params map[string]string) { func (db *Base) SetParams(params map[string]string) {
} }
var ( var dialects = map[string]func() Dialect{}
dialects = map[string]func() Dialect{}
)
// RegisterDialect register database dialect // RegisterDialect register database dialect
func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
@ -307,7 +305,7 @@ func init() {
} }
// ColumnString generate column description string according dialect // ColumnString generate column description string according dialect
func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) {
bd := strings.Builder{} bd := strings.Builder{}
if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil {
@ -322,6 +320,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err return "", err
} }
if supportCollation && col.Collation != "" {
if _, err := bd.WriteString(" COLLATE "); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Collation); err != nil {
return "", err
}
}
if includePrimaryKey && col.IsPrimaryKey { if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err return "", err

View File

@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) {
} }
func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, true)
return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s)
} }
@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(p.is_primary_key, 0), a.is_identity as is_identity ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name
from sys.columns a from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id left join sys.syscomments c on a.default_object_id=c.id
@ -475,9 +475,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
var collation *string
var maxLen, precision, scale int64 var maxLen, precision, scale int64
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -499,6 +500,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.Length = maxLen col.Length = maxLen
} }
if collation != nil {
col.Collation = *collation
}
switch ct { switch ct {
case "DATETIMEOFFSET": case "DATETIMEOFFSET":
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
@ -646,7 +650,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {

View File

@ -380,7 +380,7 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
s, _ := ColumnString(db, col, true) s, _ := ColumnString(db, col, true, true)
var b strings.Builder var b strings.Builder
b.WriteString("ALTER TABLE ") b.WriteString("ALTER TABLE ")
quoter.QuoteTo(&b, tableName) quoter.QuoteTo(&b, tableName)
@ -394,6 +394,12 @@ func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
return b.String() return b.String()
} }
// ModifyColumnSQL returns a SQL to modify SQL
func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false, true)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
}
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{db.uri.DBName, tableName} args := []interface{}{db.uri.DBName, tableName}
alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " + alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " +
@ -404,7 +410,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
"SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))"
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
@ -422,8 +428,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
var columnName, nullableStr, colType, colKey, extra, comment string var columnName, nullableStr, colType, colKey, extra, comment string
var alreadyQuoted, isUnsigned bool var alreadyQuoted, isUnsigned bool
var colDefault, maxLength *string var colDefault, maxLength, collation *string
err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -439,6 +445,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.DefaultIsEmpty = true col.DefaultIsEmpty = true
} }
if collation != nil {
col.Collation = *collation
}
fields := strings.Fields(colType) fields := strings.Fields(colType)
if len(fields) == 2 && fields[1] == "unsigned" { if len(fields) == 2 && fields[1] == "unsigned" {
@ -531,7 +540,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{db.uri.DBName} args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
@ -543,9 +552,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name, engine string var name, engine, collation string
var autoIncr, comment *string var autoIncr, comment *string
err = rows.Scan(&name, &engine, &autoIncr, &comment) err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -555,6 +564,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
table.Comment = *comment table.Comment = *comment
} }
table.StoreEngine = engine table.StoreEngine = engine
table.Collation = collation
tables = append(tables, table) tables = append(tables, table)
} }
if rows.Err() != nil { if rows.Err() != nil {
@ -646,7 +656,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {

View File

@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
/*if col.IsPrimaryKey && len(pkList) == 1 { /*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect) sql += col.String(b.dialect)
} else {*/ } else {*/
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
sql += s sql += s
// } // }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)

View File

@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl
} }
func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
addColumnSQL := "" addColumnSQL := ""

View File

@ -5,6 +5,7 @@
package integrations package integrations
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -325,14 +326,14 @@ func TestIsTableEmpty(t *testing.T) {
type PictureEmpty struct { type PictureEmpty struct {
Id int64 Id int64
Url string `xorm:"unique"` //image's url Url string `xorm:"unique"` // image's url
Title string Title string
Description string Description string
Created time.Time `xorm:"created"` Created time.Time `xorm:"created"`
ILike int ILike int
PageView int PageView int
From_url string // nolint From_url string // nolint
Pre_url string `xorm:"unique"` //pre view image's url Pre_url string `xorm:"unique"` // pre view image's url
Uid int64 Uid int64
} }
@ -458,7 +459,7 @@ func TestSync2_2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
var tableNames = make(map[string]bool) tableNames := make(map[string]bool)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("test_sync2_index_%d", i) tableName := fmt.Sprintf("test_sync2_index_%d", i)
tableNames[tableName] = true tableNames[tableName] = true
@ -536,3 +537,111 @@ func TestModifyColum(t *testing.T) {
_, err := testEngine.Exec(alterSQL) _, err := testEngine.Exec(alterSQL)
assert.NoError(t, err) assert.NoError(t, err)
} }
type TestCollateColumn struct {
Id int64
UserId int64 `xorm:"unique(s)"`
Name string `xorm:"varchar(20) unique(s)"`
dbtype string `xorm:"-"`
}
func (t TestCollateColumn) TableCollations() []*schemas.Collation {
if t.dbtype == string(schemas.MYSQL) {
return []*schemas.Collation{
{
Name: "utf8mb4_general_ci",
Column: "name",
},
}
} else if t.dbtype == string(schemas.MSSQL) {
return []*schemas.Collation{
{
Name: "Latin1_General_CI_AS",
Column: "name",
},
}
}
return nil
}
func TestCollate(t *testing.T) {
assert.NoError(t, PrepareEngine())
assertSync(t, &TestCollateColumn{
dbtype: string(testEngine.Dialect().URI().DBType),
})
_, err := testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test",
})
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
ver, err1 := testEngine.DBVersion()
assert.NoError(t, err1)
tables, err1 := testEngine.DBMetas()
assert.NoError(t, err1)
for _, table := range tables {
if table.Name == "test_collate_column" {
col := table.GetColumn("name")
if col == nil {
assert.Error(t, errors.New("not found column"))
return
}
// tidb doesn't follow utf8mb4_general_ci
if col.Collation == "utf8mb4_general_ci" && ver.Edition != "TiDB" {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
break
}
}
} else if testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Since SQLITE don't support modify column SQL, currrently just ignore
if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL {
return
}
var newCollation string
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
newCollation = "utf8mb4_bin"
} else if testEngine.Dialect().URI().DBType != schemas.MSSQL {
newCollation = "Latin1_General_CS_AS"
} else {
return
}
alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{
Name: "name",
SQLType: schemas.SQLType{
Name: "VARCHAR",
},
Length: 20,
Nullable: true,
DefaultIsEmpty: true,
Collation: newCollation,
})
_, err = testEngine.Exec(alterSQL)
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test1",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test1",
})
assert.NoError(t, err)
}

10
schemas/collation.go Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
type Collation struct {
Name string
Column string // blank means it's a table collation
}

View File

@ -45,6 +45,7 @@ type Column struct {
DisableTimeZone bool DisableTimeZone bool
TimeZone *time.Location // column specified time zone TimeZone *time.Location // column specified time zone
Comment string Comment string
Collation string
} }
// NewColumn creates a new column // NewColumn creates a new column

View File

@ -27,6 +27,7 @@ type Table struct {
StoreEngine string StoreEngine string
Charset string Charset string
Comment string Comment string
Collation string
} }
// NewEmptyTable creates an empty table // NewEmptyTable creates an empty table
@ -36,7 +37,8 @@ func NewEmptyTable() *Table {
// NewTable creates a new Table object // NewTable creates a new Table object
func NewTable(name string, t reflect.Type) *Table { func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t, return &Table{
Name: name, Type: t,
columnsSeq: make([]string, 0), columnsSeq: make([]string, 0),
columns: make([]*Column, 0), columns: make([]*Column, 0),
columnsMap: make(map[string][]*Column), columnsMap: make(map[string][]*Column),

View File

@ -31,6 +31,12 @@ type TableIndices interface {
var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem()
type TableCollations interface {
TableCollations() []*schemas.Collation
}
var tpTableCollations = reflect.TypeOf((*TableCollations)(nil)).Elem()
// Parser represents a parser for xorm tag // Parser represents a parser for xorm tag
type Parser struct { type Parser struct {
identifier string identifier string
@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
} }
} }
collations := tableCollations(v)
for _, collation := range collations {
if collation.Name == "" {
continue
}
if collation.Column == "" {
table.Collation = collation.Name
} else {
col := table.GetColumn(collation.Column)
if col == nil {
return nil, ErrUnsupportedType
}
col.Collation = collation.Name // this may override definition in struct tag
}
}
return table, nil return table, nil
} }
@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index {
} }
return nil return nil
} }
func tableCollations(v reflect.Value) []*schemas.Collation {
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
} else if v.CanAddr() {
v1 := v.Addr()
if v1.Type().Implements(tpTableCollations) {
return v1.Interface().(TableCollations).TableCollations()
}
}
return nil
}

View File

@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{
"COMMENT": CommentTagHandler, "COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler, "EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler, "UNSIGNED": UnsignedTagHandler,
"COLLATE": CollateTagHandler,
} }
func init() { func init() {
@ -282,6 +283,16 @@ func CommentTagHandler(ctx *Context) error {
return nil return nil
} }
func CollateTagHandler(ctx *Context) error {
if len(ctx.params) > 0 {
ctx.col.Collation = ctx.params[0]
} else {
ctx.col.Collation = ctx.nextTag
ctx.ignoreNext = true
}
return nil
}
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}

View File

@ -11,11 +11,12 @@ import (
) )
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { cases := []struct {
tag string tag string
tags []tag tags []tag
}{ }{
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ {
"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
{ {
name: "not", name: "not",
}, },
@ -33,13 +34,15 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"TEXT", []tag{ {
"TEXT", []tag{
{ {
name: "TEXT", name: "TEXT",
}, },
}, },
}, },
{"default('2000-01-01 00:00:00')", []tag{ {
"default('2000-01-01 00:00:00')", []tag{
{ {
name: "default", name: "default",
params: []string{ params: []string{
@ -48,7 +51,8 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"json binary", []tag{ {
"json binary", []tag{
{ {
name: "json", name: "json",
}, },
@ -57,14 +61,16 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"numeric(10, 2)", []tag{ {
"numeric(10, 2)", []tag{
{ {
name: "numeric", name: "numeric",
params: []string{"10", "2"}, params: []string{"10", "2"},
}, },
}, },
}, },
{"numeric(10, 2) notnull", []tag{ {
"numeric(10, 2) notnull", []tag{
{ {
name: "numeric", name: "numeric",
params: []string{"10", "2"}, params: []string{"10", "2"},
@ -74,6 +80,16 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{
"collate utf8mb4_bin", []tag{
{
name: "collate",
},
{
name: "utf8mb4_bin",
},
},
},
} }
for _, kase := range cases { for _, kase := range cases {
@ -82,7 +98,7 @@ func TestSplitTag(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, len(tags), len(kase.tags)) assert.EqualValues(t, len(tags), len(kase.tags))
for i := 0; i < len(tags); i++ { for i := 0; i < len(tags); i++ {
assert.Equal(t, tags[i], kase.tags[i]) assert.Equal(t, kase.tags[i], tags[i])
} }
}) })
} }