Merge branch 'xorm-master'

This commit is contained in:
brookechen 2023-07-10 20:04:18 +08:00
commit a678ef1c27
19 changed files with 670 additions and 108 deletions

View File

@ -24,7 +24,10 @@ func Interface2Interface(userLocation *time.Location, v interface{}) (interface{
return vv.String, nil return vv.String, nil
case *sql.RawBytes: case *sql.RawBytes:
if len([]byte(*vv)) > 0 { if len([]byte(*vv)) > 0 {
return []byte(*vv), nil src := []byte(*vv)
dest := make([]byte, len(src))
copy(dest, src)
return dest, nil
} }
return nil, nil return nil, nil
case *sql.NullInt32: case *sql.NullInt32:

View File

@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s)
} }
@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
} }
} }
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
if _, err := b.WriteString(s); err != nil { if _, err := b.WriteString(s); err != nil {
return "", false, err return "", false, err
} }
@ -709,7 +709,13 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
return "", false, err return "", false, err
} }
} }
if _, err := b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)); err != nil { if _, err := b.WriteString("CONSTRAINT PK_"); err != nil {
return "", false, err
}
if _, err := b.WriteString(tableName); err != nil {
return "", false, err
}
if _, err := b.WriteString(" PRIMARY KEY ("); err != nil {
return "", false, err return "", false, err
} }
if err := quoter.JoinWrite(&b, pkList, ","); err != nil { if err := quoter.JoinWrite(&b, pkList, ","); err != nil {
@ -837,7 +843,11 @@ func addSingleQuote(name string) string {
if name[0] == '\'' && name[len(name)-1] == '\'' { if name[0] == '\'' && name[len(name)-1] == '\'' {
return name return name
} }
return fmt.Sprintf("'%s'", name) var b strings.Builder
b.WriteRune('\'')
b.WriteString(name)
b.WriteRune('\'')
return b.String()
} }
func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {

View File

@ -135,7 +135,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {
@ -209,7 +209,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
// AddColumnSQL returns a SQL to add a column // AddColumnSQL returns a SQL to add a column
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s)
} }
@ -241,7 +241,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
} }
@ -254,9 +254,7 @@ func (db *Base) ForUpdateSQL(query string) string {
func (db *Base) SetParams(params map[string]string) { func (db *Base) SetParams(params map[string]string) {
} }
var ( var dialects = map[string]func() Dialect{}
dialects = map[string]func() Dialect{}
)
// RegisterDialect register database dialect // RegisterDialect register database dialect
func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
@ -307,7 +305,7 @@ func init() {
} }
// ColumnString generate column description string according dialect // ColumnString generate column description string according dialect
func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) {
bd := strings.Builder{} bd := strings.Builder{}
if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil {
@ -322,6 +320,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err return "", err
} }
if supportCollation && col.Collation != "" {
if _, err := bd.WriteString(" COLLATE "); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Collation); err != nil {
return "", err
}
}
if includePrimaryKey && col.IsPrimaryKey { if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err return "", err

View File

@ -6,7 +6,7 @@ package dialects
import ( import (
"context" "context"
"fmt" "strconv"
"strings" "strings"
) )
@ -29,10 +29,11 @@ func convertQuestionMark(sql, prefix string, start int) string {
var isMaybeLineComment bool var isMaybeLineComment bool
var isMaybeComment bool var isMaybeComment bool
var isMaybeCommentEnd bool var isMaybeCommentEnd bool
var index = start index := start
for _, c := range sql { for _, c := range sql {
if !beginSingleQuote && !isLineComment && !isComment && c == '?' { if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) buf.WriteString(prefix)
buf.WriteString(strconv.Itoa(index))
index++ index++
} else { } else {
if isMaybeLineComment { if isMaybeLineComment {

View File

@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) {
} }
func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, true)
return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s)
} }
@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(p.is_primary_key, 0), a.is_identity as is_identity ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name
from sys.columns a from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id left join sys.syscomments c on a.default_object_id=c.id
@ -475,9 +475,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
var collation *string
var maxLen, precision, scale int64 var maxLen, precision, scale int64
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -499,6 +500,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.Length = maxLen col.Length = maxLen
} }
if collation != nil {
col.Collation = *collation
}
switch ct { switch ct {
case "DATETIMEOFFSET": case "DATETIMEOFFSET":
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
@ -646,7 +650,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {

View File

@ -380,12 +380,24 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
s, _ := ColumnString(db, col, true) s, _ := ColumnString(db, col, true, true)
sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) var b strings.Builder
b.WriteString("ALTER TABLE ")
quoter.QuoteTo(&b, tableName)
b.WriteString(" ADD ")
b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {
sql += " COMMENT '" + col.Comment + "'" b.WriteString(" COMMENT '")
b.WriteString(col.Comment)
b.WriteString("'")
} }
return sql return b.String()
}
// ModifyColumnSQL returns a SQL to modify SQL
func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false, true)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
} }
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
@ -398,7 +410,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
"SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))"
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
@ -416,8 +428,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
var columnName, nullableStr, colType, colKey, extra, comment string var columnName, nullableStr, colType, colKey, extra, comment string
var alreadyQuoted, isUnsigned bool var alreadyQuoted, isUnsigned bool
var colDefault, maxLength *string var colDefault, maxLength, collation *string
err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -433,6 +445,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.DefaultIsEmpty = true col.DefaultIsEmpty = true
} }
if collation != nil {
col.Collation = *collation
}
fields := strings.Fields(colType) fields := strings.Fields(colType)
if len(fields) == 2 && fields[1] == "unsigned" { if len(fields) == 2 && fields[1] == "unsigned" {
@ -525,7 +540,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{db.uri.DBName} args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
@ -537,9 +552,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name, engine string var name, engine, collation string
var autoIncr, comment *string var autoIncr, comment *string
err = rows.Scan(&name, &engine, &autoIncr, &comment) err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -549,6 +564,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
table.Comment = *comment table.Comment = *comment
} }
table.StoreEngine = engine table.StoreEngine = engine
table.Collation = collation
tables = append(tables, table) tables = append(tables, table)
} }
if rows.Err() != nil { if rows.Err() != nil {
@ -640,7 +656,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {

View File

@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
/*if col.IsPrimaryKey && len(pkList) == 1 { /*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect) sql += col.String(b.dialect)
} else {*/ } else {*/
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
sql += s sql += s
// } // }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)

View File

@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl
} }
func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
addColumnSQL := "" addColumnSQL := ""
@ -1078,7 +1078,7 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
schema := db.getSchema() schema := db.getSchema()
if schema != "" { if schema != "" {

View File

@ -1120,21 +1120,6 @@ func (engine *Engine) UnMapType(t reflect.Type) {
engine.tagParser.ClearCacheTable(t) engine.tagParser.ClearCacheTable(t)
} }
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.Sync(beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (engine *Engine) Sync2(beans ...interface{}) error {
return engine.Sync(beans...)
}
// CreateTables create tabls according bean // CreateTables create tabls according bean
func (engine *Engine) CreateTables(beans ...interface{}) error { func (engine *Engine) CreateTables(beans ...interface{}) error {
session := engine.NewSession() session := engine.NewSession()

View File

@ -5,6 +5,7 @@
package integrations package integrations
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -325,14 +326,14 @@ func TestIsTableEmpty(t *testing.T) {
type PictureEmpty struct { type PictureEmpty struct {
Id int64 Id int64
Url string `xorm:"unique"` //image's url Url string `xorm:"unique"` // image's url
Title string Title string
Description string Description string
Created time.Time `xorm:"created"` Created time.Time `xorm:"created"`
ILike int ILike int
PageView int PageView int
From_url string // nolint From_url string // nolint
Pre_url string `xorm:"unique"` //pre view image's url Pre_url string `xorm:"unique"` // pre view image's url
Uid int64 Uid int64
} }
@ -458,7 +459,7 @@ func TestSync2_2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
var tableNames = make(map[string]bool) tableNames := make(map[string]bool)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("test_sync2_index_%d", i) tableName := fmt.Sprintf("test_sync2_index_%d", i)
tableNames[tableName] = true tableNames[tableName] = true
@ -536,3 +537,111 @@ func TestModifyColum(t *testing.T) {
_, err := testEngine.Exec(alterSQL) _, err := testEngine.Exec(alterSQL)
assert.NoError(t, err) assert.NoError(t, err)
} }
type TestCollateColumn struct {
Id int64
UserId int64 `xorm:"unique(s)"`
Name string `xorm:"varchar(20) unique(s)"`
dbtype string `xorm:"-"`
}
func (t TestCollateColumn) TableCollations() []*schemas.Collation {
if t.dbtype == string(schemas.MYSQL) {
return []*schemas.Collation{
{
Name: "utf8mb4_general_ci",
Column: "name",
},
}
} else if t.dbtype == string(schemas.MSSQL) {
return []*schemas.Collation{
{
Name: "Latin1_General_CI_AS",
Column: "name",
},
}
}
return nil
}
func TestCollate(t *testing.T) {
assert.NoError(t, PrepareEngine())
assertSync(t, &TestCollateColumn{
dbtype: string(testEngine.Dialect().URI().DBType),
})
_, err := testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test",
})
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
ver, err1 := testEngine.DBVersion()
assert.NoError(t, err1)
tables, err1 := testEngine.DBMetas()
assert.NoError(t, err1)
for _, table := range tables {
if table.Name == "test_collate_column" {
col := table.GetColumn("name")
if col == nil {
assert.Error(t, errors.New("not found column"))
return
}
// tidb doesn't follow utf8mb4_general_ci
if col.Collation == "utf8mb4_general_ci" && ver.Edition != "TiDB" {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
break
}
}
} else if testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Since SQLITE don't support modify column SQL, currrently just ignore
if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL {
return
}
var newCollation string
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
newCollation = "utf8mb4_bin"
} else if testEngine.Dialect().URI().DBType != schemas.MSSQL {
newCollation = "Latin1_General_CS_AS"
} else {
return
}
alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{
Name: "name",
SQLType: schemas.SQLType{
Name: "VARCHAR",
},
Length: 20,
Nullable: true,
DefaultIsEmpty: true,
Collation: newCollation,
})
_, err = testEngine.Exec(alterSQL)
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test1",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test1",
})
assert.NoError(t, err)
}

View File

@ -5,11 +5,13 @@
package integrations package integrations
import ( import (
"bytes"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -381,3 +383,68 @@ func TestQueryStringWithLimit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 0, len(data)) assert.EqualValues(t, 0, len(data))
} }
func TestQueryBLOBInMySQL(t *testing.T) {
assert.NoError(t, PrepareEngine())
var err error
type Avatar struct {
Id int64 `xorm:"autoincr pk"`
Avatar []byte `xorm:"BLOB"`
}
assert.NoError(t, testEngine.Sync(new(Avatar)))
testEngine.Delete(Avatar{})
repeatBytes := func(n int, b byte) []byte {
return bytes.Repeat([]byte{b}, n)
}
const N = 10
var data = []Avatar{}
for i := 0; i < N; i++ {
// allocate a []byte that is as twice big as the last one
// so that the underlying buffer will need to reallocate when querying
bs := repeatBytes(1<<(i+2), 'A'+byte(i))
data = append(data, Avatar{
Avatar: bs,
})
}
_, err = testEngine.Insert(data)
assert.NoError(t, err)
defer func() {
testEngine.Delete(Avatar{})
}()
{
records, err := testEngine.QueryInterface("select avatar from " + testEngine.Quote(testEngine.TableName("avatar", true)))
assert.NoError(t, err)
for i, record := range records {
bs := record["avatar"].([]byte)
assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3])
t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2])
}
}
{
arr := make([][]interface{}, 0)
err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr)
assert.NoError(t, err)
for i, record := range arr {
bs := record[0].([]byte)
assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3])
t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2])
}
}
{
arr := make([]map[string]interface{}, 0)
err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr)
assert.NoError(t, err)
for i, record := range arr {
bs := record["avatar"].([]byte)
assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3])
t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2])
}
}
}

View File

@ -6,6 +6,7 @@ package statements
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
@ -26,14 +27,19 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string {
return "" return ""
} }
var top string var b strings.Builder
b.WriteString("SELECT ")
pLimitN := statement.LimitN pLimitN := statement.LimitN
if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN) b.WriteString("TOP ")
b.WriteString(strconv.Itoa(*pLimitN))
b.WriteString(" ")
} }
b.WriteString(colstrs)
b.WriteString(" FROM ")
b.WriteString(sqls[1])
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) return b.String()
return newsql
} }
return "" return ""
} }
@ -54,7 +60,7 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
return "", "" return "", ""
} }
var whereStr = sqls[1] whereStr := sqls[1]
// TODO: for postgres only, if any other database? // TODO: for postgres only, if any other database?
var paraStr string var paraStr string

10
schemas/collation.go Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
type Collation struct {
Name string
Column string // blank means it's a table collation
}

View File

@ -45,6 +45,7 @@ type Column struct {
DisableTimeZone bool DisableTimeZone bool
TimeZone *time.Location // column specified time zone TimeZone *time.Location // column specified time zone
Comment string Comment string
Collation string
} }
// NewColumn creates a new column // NewColumn creates a new column

View File

@ -27,6 +27,7 @@ type Table struct {
StoreEngine string StoreEngine string
Charset string Charset string
Comment string Comment string
Collation string
} }
// NewEmptyTable creates an empty table // NewEmptyTable creates an empty table
@ -36,7 +37,8 @@ func NewEmptyTable() *Table {
// NewTable creates a new Table object // NewTable creates a new Table object
func NewTable(name string, t reflect.Type) *Table { func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t, return &Table{
Name: name, Type: t,
columnsSeq: make([]string, 0), columnsSeq: make([]string, 0),
columns: make([]*Column, 0), columns: make([]*Column, 0),
columnsMap: make(map[string][]*Column), columnsMap: make(map[string][]*Column),

273
sync.go Normal file
View File

@ -0,0 +1,273 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
type SyncOptions struct {
WarnIfDatabaseColumnMissed bool
}
type SyncResult struct{}
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.Sync(beans...)
}
// SyncWithOptions sync the database schemas according options and table structs
func (engine *Engine) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) {
session := engine.NewSession()
defer session.Close()
return session.SyncWithOptions(opts, beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (engine *Engine) Sync2(beans ...interface{}) error {
return engine.Sync(beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (session *Session) Sync2(beans ...interface{}) error {
return session.Sync(beans...)
}
// Sync synchronize structs to database tables
func (session *Session) Sync(beans ...interface{}) error {
_, err := session.SyncWithOptions(SyncOptions{
WarnIfDatabaseColumnMissed: false,
}, beans...)
return err
}
func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) {
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil {
return nil, err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var syncResult SyncResult
for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return nil, err
}
var tbName string
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
break
}
}
// this is a new table
if oriTable == nil {
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return nil, err
}
err = session.createUniques(bean)
if err != nil {
return nil, err
}
err = session.createIndexes(bean)
if err != nil {
return nil, err
}
continue
}
// this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil {
return nil, err
}
// check columns
for _, col := range table.Columns() {
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
break
}
}
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil {
return nil, err
}
continue
}
err = nil
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType)
}
}
}
} else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else if col.Comment != oriCol.Comment {
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
if col.Default != oriCol.Default {
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default)
}
}
if col.Nullable != oriCol.Nullable {
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
if err != nil {
return nil, err
}
}
foundIndexNames := make(map[string]bool)
addedNames := make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
foundIndexNames[name2] = true
break
}
}
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return nil, err
}
oriIndex = nil
}
}
if oriIndex == nil {
addedNames[name] = index
}
}
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return nil, err
}
}
}
for name, index := range addedNames {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return nil, err
}
}
if opts.WarnIfDatabaseColumnMissed {
// check all the columns which removed from struct fields but left on database tables.
for _, colName := range oriTable.ColumnsSeq() {
if table.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
}
}
}
}
return &syncResult, nil
}

View File

@ -31,6 +31,12 @@ type TableIndices interface {
var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem()
type TableCollations interface {
TableCollations() []*schemas.Collation
}
var tpTableCollations = reflect.TypeOf((*TableCollations)(nil)).Elem()
// Parser represents a parser for xorm tag // Parser represents a parser for xorm tag
type Parser struct { type Parser struct {
identifier string identifier string
@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
} }
} }
collations := tableCollations(v)
for _, collation := range collations {
if collation.Name == "" {
continue
}
if collation.Column == "" {
table.Collation = collation.Name
} else {
col := table.GetColumn(collation.Column)
if col == nil {
return nil, ErrUnsupportedType
}
col.Collation = collation.Name // this may override definition in struct tag
}
}
return table, nil return table, nil
} }
@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index {
} }
return nil return nil
} }
func tableCollations(v reflect.Value) []*schemas.Collation {
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
} else if v.CanAddr() {
v1 := v.Addr()
if v1.Type().Implements(tpTableCollations) {
return v1.Interface().(TableCollations).TableCollations()
}
}
return nil
}

View File

@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{
"COMMENT": CommentTagHandler, "COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler, "EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler, "UNSIGNED": UnsignedTagHandler,
"COLLATE": CollateTagHandler,
} }
func init() { func init() {
@ -282,6 +283,16 @@ func CommentTagHandler(ctx *Context) error {
return nil return nil
} }
func CollateTagHandler(ctx *Context) error {
if len(ctx.params) > 0 {
ctx.col.Collation = ctx.params[0]
} else {
ctx.col.Collation = ctx.nextTag
ctx.ignoreNext = true
}
return nil
}
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}

View File

@ -11,68 +11,84 @@ import (
) )
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { cases := []struct {
tag string tag string
tags []tag tags []tag
}{ }{
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ {
{ "not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
name: "not", {
}, name: "not",
{ },
name: "null", {
}, name: "null",
{ },
name: "default", {
}, name: "default",
{ },
name: "'2000-01-01 00:00:00'", {
}, name: "'2000-01-01 00:00:00'",
{ },
name: "TIMESTAMP", {
}, name: "TIMESTAMP",
},
},
{"TEXT", []tag{
{
name: "TEXT",
},
},
},
{"default('2000-01-01 00:00:00')", []tag{
{
name: "default",
params: []string{
"'2000-01-01 00:00:00'",
}, },
}, },
}, },
}, {
{"json binary", []tag{ "TEXT", []tag{
{ {
name: "json", name: "TEXT",
}, },
{
name: "binary",
}, },
}, },
}, {
{"numeric(10, 2)", []tag{ "default('2000-01-01 00:00:00')", []tag{
{ {
name: "numeric", name: "default",
params: []string{"10", "2"}, params: []string{
"'2000-01-01 00:00:00'",
},
},
}, },
}, },
}, {
{"numeric(10, 2) notnull", []tag{ "json binary", []tag{
{ {
name: "numeric", name: "json",
params: []string{"10", "2"}, },
}, {
{ name: "binary",
name: "notnull", },
}, },
}, },
{
"numeric(10, 2)", []tag{
{
name: "numeric",
params: []string{"10", "2"},
},
},
},
{
"numeric(10, 2) notnull", []tag{
{
name: "numeric",
params: []string{"10", "2"},
},
{
name: "notnull",
},
},
},
{
"collate utf8mb4_bin", []tag{
{
name: "collate",
},
{
name: "utf8mb4_bin",
},
},
}, },
} }
@ -82,7 +98,7 @@ func TestSplitTag(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, len(tags), len(kase.tags)) assert.EqualValues(t, len(tags), len(kase.tags))
for i := 0; i < len(tags); i++ { for i := 0; i < len(tags); i++ {
assert.Equal(t, tags[i], kase.tags[i]) assert.Equal(t, kase.tags[i], tags[i])
} }
}) })
} }