Merge branch 'master' into lunny/dm_ci

This commit is contained in:
Lunny Xiao 2023-07-12 17:01:28 +08:00
commit ed1797d0c5
28 changed files with 1000 additions and 629 deletions

View File

@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s)
} }
@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
} }
} }
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
if _, err := b.WriteString(s); err != nil { if _, err := b.WriteString(s); err != nil {
return "", false, err return "", false, err
} }

View File

@ -85,8 +85,6 @@ type Dialect interface {
AddColumnSQL(tableName string, col *schemas.Column) string AddColumnSQL(tableName string, col *schemas.Column) string
ModifyColumnSQL(tableName string, col *schemas.Column) string ModifyColumnSQL(tableName string, col *schemas.Column) string
ForUpdateSQL(query string) string
Filters() []Filter Filters() []Filter
SetParams(params map[string]string) SetParams(params map[string]string)
} }
@ -135,7 +133,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {
@ -209,7 +207,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa
// AddColumnSQL returns a SQL to add a column // AddColumnSQL returns a SQL to add a column
func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s)
} }
@ -241,22 +239,15 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string {
// ModifyColumnSQL returns a SQL to modify SQL // ModifyColumnSQL returns a SQL to modify SQL
func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, false)
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
} }
// ForUpdateSQL returns for updateSQL
func (db *Base) ForUpdateSQL(query string) string {
return query + " FOR UPDATE"
}
// SetParams set params // SetParams set params
func (db *Base) SetParams(params map[string]string) { func (db *Base) SetParams(params map[string]string) {
} }
var ( var dialects = map[string]func() Dialect{}
dialects = map[string]func() Dialect{}
)
// RegisterDialect register database dialect // RegisterDialect register database dialect
func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) {
@ -307,7 +298,7 @@ func init() {
} }
// ColumnString generate column description string according dialect // ColumnString generate column description string according dialect
func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) {
bd := strings.Builder{} bd := strings.Builder{}
if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil {
@ -322,6 +313,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool)
return "", err return "", err
} }
if supportCollation && col.Collation != "" {
if _, err := bd.WriteString(" COLLATE "); err != nil {
return "", err
}
if _, err := bd.WriteString(col.Collation); err != nil {
return "", err
}
}
if includePrimaryKey && col.IsPrimaryKey { if includePrimaryKey && col.IsPrimaryKey {
if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil {
return "", err return "", err

View File

@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) {
} }
func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false) s, _ := ColumnString(db.dialect, col, false, true)
return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s)
} }
@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(p.is_primary_key, 0), a.is_identity as is_identity ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name
from sys.columns a from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id left join sys.syscomments c on a.default_object_id=c.id
@ -475,9 +475,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
var collation *string
var maxLen, precision, scale int64 var maxLen, precision, scale int64
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -499,6 +500,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.Length = maxLen col.Length = maxLen
} }
if collation != nil {
col.Collation = *collation
}
switch ct { switch ct {
case "DATETIMEOFFSET": case "DATETIMEOFFSET":
col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0}
@ -646,7 +650,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if i != len(table.ColumnsSeq())-1 { if i != len(table.ColumnsSeq())-1 {
@ -665,10 +669,6 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
return b.String(), true, nil return b.String(), true, nil
} }
func (db *mssql) ForUpdateSQL(query string) string {
return query
}
func (db *mssql) Filters() []Filter { func (db *mssql) Filters() []Filter {
return []Filter{} return []Filter{}
} }

View File

@ -38,6 +38,7 @@ var (
"CALL": true, "CALL": true,
"CASCADE": true, "CASCADE": true,
"CASE": true, "CASE": true,
"CHAIN": true,
"CHANGE": true, "CHANGE": true,
"CHAR": true, "CHAR": true,
"CHARACTER": true, "CHARACTER": true,
@ -128,6 +129,7 @@ var (
"OUT": true, "OUTER": true, "OUTFILE": true, "OUT": true, "OUTER": true, "OUTFILE": true,
"PRECISION": true, "PRIMARY": true, "PROCEDURE": true, "PRECISION": true, "PRIMARY": true, "PROCEDURE": true,
"PURGE": true, "RAID0": true, "RANGE": true, "PURGE": true, "RAID0": true, "RANGE": true,
"RANK": true,
"READ": true, "READS": true, "REAL": true, "READ": true, "READS": true, "REAL": true,
"REFERENCES": true, "REGEXP": true, "RELEASE": true, "REFERENCES": true, "REGEXP": true, "RELEASE": true,
"RENAME": true, "REPEAT": true, "REPLACE": true, "RENAME": true, "REPEAT": true, "REPLACE": true,
@ -380,7 +382,7 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa
func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
s, _ := ColumnString(db, col, true) s, _ := ColumnString(db, col, true, true)
var b strings.Builder var b strings.Builder
b.WriteString("ALTER TABLE ") b.WriteString("ALTER TABLE ")
quoter.QuoteTo(&b, tableName) quoter.QuoteTo(&b, tableName)
@ -394,6 +396,15 @@ func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string {
return b.String() return b.String()
} }
// ModifyColumnSQL returns a SQL to modify SQL
func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, false, true)
if col.Comment != "" {
s += fmt.Sprintf(" COMMENT '%s'", col.Comment)
}
return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s)
}
func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{db.uri.DBName, tableName} args := []interface{}{db.uri.DBName, tableName}
alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " + alreadyQuoted := "(INSTR(VERSION(), 'maria') > 0 && " +
@ -404,7 +415,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
"SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))"
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
@ -422,8 +433,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
var columnName, nullableStr, colType, colKey, extra, comment string var columnName, nullableStr, colType, colKey, extra, comment string
var alreadyQuoted, isUnsigned bool var alreadyQuoted, isUnsigned bool
var colDefault, maxLength *string var colDefault, maxLength, collation *string
err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -439,6 +450,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
col.DefaultIsEmpty = true col.DefaultIsEmpty = true
} }
if collation != nil {
col.Collation = *collation
}
fields := strings.Fields(colType) fields := strings.Fields(colType)
if len(fields) == 2 && fields[1] == "unsigned" { if len(fields) == 2 && fields[1] == "unsigned" {
@ -531,7 +545,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) {
args := []interface{}{db.uri.DBName} args := []interface{}{db.uri.DBName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " +
"`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
@ -543,9 +557,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
tables := make([]*schemas.Table, 0) tables := make([]*schemas.Table, 0)
for rows.Next() { for rows.Next() {
table := schemas.NewEmptyTable() table := schemas.NewEmptyTable()
var name, engine string var name, engine, collation string
var autoIncr, comment *string var autoIncr, comment *string
err = rows.Scan(&name, &engine, &autoIncr, &comment) err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -555,6 +569,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema
table.Comment = *comment table.Comment = *comment
} }
table.StoreEngine = engine table.StoreEngine = engine
table.Collation = collation
tables = append(tables, table) tables = append(tables, table)
} }
if rows.Err() != nil { if rows.Err() != nil {
@ -646,7 +661,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table
for i, colName := range table.ColumnsSeq() { for i, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName) col := table.GetColumn(colName)
s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true)
b.WriteString(s) b.WriteString(s)
if len(col.Comment) > 0 { if len(col.Comment) > 0 {

View File

@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
/*if col.IsPrimaryKey && len(pkList) == 1 { /*if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect) sql += col.String(b.dialect)
} else {*/ } else {*/
s, _ := ColumnString(db, col, false) s, _ := ColumnString(db, col, false, false)
sql += s sql += s
// } // }
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)

View File

@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl
} }
func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string {
s, _ := ColumnString(db.dialect, col, true) s, _ := ColumnString(db.dialect, col, true, false)
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
addColumnSQL := "" addColumnSQL := ""
@ -1078,7 +1078,7 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
schema := db.getSchema() schema := db.getSchema()
if schema != "" { if schema != "" {

View File

@ -193,11 +193,11 @@ func (db *sqlite3) Features() *DialectFeatures {
func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *sqlite3) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = sqlite3Quoter q := sqlite3Quoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = sqlite3Quoter q := sqlite3Quoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -291,10 +291,6 @@ func (db *sqlite3) DropIndexSQL(tableName string, index *schemas.Index) string {
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName))
} }
func (db *sqlite3) ForUpdateSQL(query string) string {
return query
}
func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) {
query := "SELECT * FROM " + tableName + " LIMIT 0" query := "SELECT * FROM " + tableName + " LIMIT 0"
rows, err := queryer.QueryContext(ctx, query) rows, err := queryer.QueryContext(ctx, query)
@ -320,7 +316,7 @@ func (db *sqlite3) IsColumnExist(queryer core.Queryer, ctx context.Context, tabl
// splitColStr splits a sqlite col strings as fields // splitColStr splits a sqlite col strings as fields
func splitColStr(colStr string) []string { func splitColStr(colStr string) []string {
colStr = strings.TrimSpace(colStr) colStr = strings.TrimSpace(colStr)
var results = make([]string, 0, 10) results := make([]string, 0, 10)
var lastIdx int var lastIdx int
var hasC, hasQuote bool var hasC, hasQuote bool
for i, c := range colStr { for i, c := range colStr {

View File

@ -1120,21 +1120,6 @@ func (engine *Engine) UnMapType(t reflect.Type) {
engine.tagParser.ClearCacheTable(t) engine.tagParser.ClearCacheTable(t)
} }
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.Sync(beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (engine *Engine) Sync2(beans ...interface{}) error {
return engine.Sync(beans...)
}
// CreateTables create tabls according bean // CreateTables create tabls according bean
func (engine *Engine) CreateTables(beans ...interface{}) error { func (engine *Engine) CreateTables(beans ...interface{}) error {
session := engine.NewSession() session := engine.NewSession()

View File

@ -289,6 +289,48 @@ func TestGetColumnsComment(t *testing.T) {
assert.Zero(t, noComment) assert.Zero(t, noComment)
} }
type TestCommentUpdate struct {
HasComment int `xorm:"bigint comment('this is a comment before update')"`
}
func (m *TestCommentUpdate) TableName() string {
return "test_comment_struct"
}
type TestCommentUpdate2 struct {
HasComment int `xorm:"bigint comment('this is a comment after update')"`
}
func (m *TestCommentUpdate2) TableName() string {
return "test_comment_struct"
}
func TestColumnCommentUpdate(t *testing.T) {
comment := "this is a comment after update"
assertSync(t, new(TestCommentUpdate))
assert.NoError(t, testEngine.Sync2(new(TestCommentUpdate2))) // modify table column comment
switch testEngine.Dialect().URI().DBType {
case schemas.POSTGRES, schemas.MYSQL: // only postgres / mysql dialect implement the feature of modify comment in postgres.ModifyColumnSQL
default:
t.Skip()
return
}
tables, err := testEngine.DBMetas()
assert.NoError(t, err)
tableName := "test_comment_struct"
var hasComment string
for _, table := range tables {
if table.Name == tableName {
col := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("HasComment"))
assert.NotNil(t, col)
hasComment = col.Comment
break
}
}
assert.Equal(t, comment, hasComment)
}
func TestGetColumnsLength(t *testing.T) { func TestGetColumnsLength(t *testing.T) {
var max_length int64 var max_length int64
switch testEngine.Dialect().URI().DBType { switch testEngine.Dialect().URI().DBType {

View File

@ -5,6 +5,7 @@
package integrations package integrations
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"testing" "testing"
@ -458,7 +459,7 @@ func TestSync2_2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
var tableNames = make(map[string]bool) tableNames := make(map[string]bool)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("test_sync2_index_%d", i) tableName := fmt.Sprintf("test_sync2_index_%d", i)
tableNames[tableName] = true tableNames[tableName] = true
@ -536,3 +537,111 @@ func TestModifyColum(t *testing.T) {
_, err := testEngine.Exec(alterSQL) _, err := testEngine.Exec(alterSQL)
assert.NoError(t, err) assert.NoError(t, err)
} }
type TestCollateColumn struct {
Id int64
UserId int64 `xorm:"unique(s)"`
Name string `xorm:"varchar(20) unique(s)"`
dbtype string `xorm:"-"`
}
func (t TestCollateColumn) TableCollations() []*schemas.Collation {
if t.dbtype == string(schemas.MYSQL) {
return []*schemas.Collation{
{
Name: "utf8mb4_general_ci",
Column: "name",
},
}
} else if t.dbtype == string(schemas.MSSQL) {
return []*schemas.Collation{
{
Name: "Latin1_General_CI_AS",
Column: "name",
},
}
}
return nil
}
func TestCollate(t *testing.T) {
assert.NoError(t, PrepareEngine())
assertSync(t, &TestCollateColumn{
dbtype: string(testEngine.Dialect().URI().DBType),
})
_, err := testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test",
})
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
ver, err1 := testEngine.DBVersion()
assert.NoError(t, err1)
tables, err1 := testEngine.DBMetas()
assert.NoError(t, err1)
for _, table := range tables {
if table.Name == "test_collate_column" {
col := table.GetColumn("name")
if col == nil {
assert.Error(t, errors.New("not found column"))
return
}
// tidb doesn't follow utf8mb4_general_ci
if col.Collation == "utf8mb4_general_ci" && ver.Edition != "TiDB" {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
break
}
}
} else if testEngine.Dialect().URI().DBType == schemas.MSSQL {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// Since SQLITE don't support modify column SQL, currrently just ignore
if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL {
return
}
var newCollation string
if testEngine.Dialect().URI().DBType == schemas.MYSQL {
newCollation = "utf8mb4_bin"
} else if testEngine.Dialect().URI().DBType != schemas.MSSQL {
newCollation = "Latin1_General_CS_AS"
} else {
return
}
alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{
Name: "name",
SQLType: schemas.SQLType{
Name: "VARCHAR",
},
Length: 20,
Nullable: true,
DefaultIsEmpty: true,
Collation: newCollation,
})
_, err = testEngine.Exec(alterSQL)
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "test1",
})
assert.NoError(t, err)
_, err = testEngine.Insert(&TestCollateColumn{
UserId: 1,
Name: "Test1",
})
assert.NoError(t, err)
}

View File

@ -89,7 +89,7 @@ func TestCountWithOthers(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
total, err := testEngine.OrderBy("`id` desc").Limit(1).Count(new(CountWithOthers)) total, err := testEngine.OrderBy("count(`id`) desc").Limit(1).Count(new(CountWithOthers))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
} }
@ -118,11 +118,11 @@ func TestWithTableName(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
total, err := testEngine.OrderBy("`id` desc").Count(new(CountWithTableName)) total, err := testEngine.OrderBy("count(`id`) desc").Count(new(CountWithTableName))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
total, err = testEngine.OrderBy("`id` desc").Count(CountWithTableName{}) total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
} }

View File

@ -12,6 +12,7 @@ import (
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/names" "xorm.io/xorm/names"
"xorm.io/xorm/schemas"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -1196,3 +1197,43 @@ func TestUpdateFindDate(t *testing.T) {
assert.EqualValues(t, 1, len(tufs)) assert.EqualValues(t, 1, len(tufs))
assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02")) assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02"))
} }
func TestBuilderDialect(t *testing.T) {
assert.NoError(t, PrepareEngine())
type TestBuilderDialect struct {
Id int64
Name string `xorm:"index"`
Age2 int
}
type TestBuilderDialectFoo struct {
Id int64
DialectId int64 `xorm:"index"`
Age int
}
assertSync(t, new(TestBuilderDialect), new(TestBuilderDialectFoo))
session := testEngine.NewSession()
defer session.Close()
var dialect string
switch testEngine.Dialect().URI().DBType {
case schemas.MYSQL:
dialect = builder.MYSQL
case schemas.MSSQL:
dialect = builder.MSSQL
case schemas.POSTGRES:
dialect = builder.POSTGRES
case schemas.SQLITE:
dialect = builder.SQLITE
}
tbName := testEngine.TableName(new(TestBuilderDialectFoo), dialect == builder.POSTGRES)
inner := builder.Dialect(dialect).Select("*").From(tbName).Where(builder.Eq{"age": 20})
result := make([]*TestBuilderDialect, 0, 10)
err := testEngine.Table("test_builder_dialect").Where(builder.Eq{"age2": 2}).Join("INNER", inner, "test_builder_dialect_foo.dialect_id = test_builder_dialect.id").Find(&result)
assert.NoError(t, err)
}

View File

@ -15,82 +15,97 @@ import (
) )
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition interface{}, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, joinTable interface{}, condition interface{}, args ...interface{}) *Statement {
var buf strings.Builder statement.joins = append(statement.joins, join{
if len(statement.JoinStr) > 0 { op: joinOP,
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) table: joinTable,
} else { condition: condition,
fmt.Fprintf(&buf, "%v JOIN ", joinOP) args: args,
} })
condStr := ""
condArgs := []interface{}{}
switch condTp := condition.(type) {
case string:
condStr = condTp
case builder.Cond:
var err error
condStr, condArgs, err = builder.ToSQL(condTp)
if err != nil {
statement.LastError = err
return statement
}
default:
statement.LastError = fmt.Errorf("unsupported join condition type: %v", condTp)
return statement return statement
} }
switch tp := tablename.(type) { func (statement *Statement) writeJoins(w *builder.BytesWriter) error {
case builder.Builder: for _, join := range statement.joins {
subSQL, subQueryArgs, err := tp.ToSQL() if err := statement.writeJoin(w, join); err != nil {
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
case *builder.Builder:
subSQL, subQueryArgs, err := tp.ToSQL()
if err != nil {
statement.LastError = err
return statement
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
fmt.Fprintf(&buf, "(%s) %s ON %v", statement.ReplaceQuote(subSQL), statement.quote(aliasName), statement.ReplaceQuote(condStr))
statement.joinArgs = append(append(statement.joinArgs, subQueryArgs...), condArgs...)
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true)
if !utils.IsSubQuery(tbName) {
var buf strings.Builder
_ = statement.dialect.Quoter().QuoteTo(&buf, tbName)
tbName = buf.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
fmt.Fprintf(&buf, "%s ON %v", tbName, statement.ReplaceQuote(condStr))
statement.joinArgs = append(statement.joinArgs, condArgs...)
}
statement.JoinStr = buf.String()
statement.joinArgs = append(statement.joinArgs, args...)
return statement
}
func (statement *Statement) writeJoin(w builder.Writer) error {
if statement.JoinStr != "" {
if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
return err return err
} }
w.Append(statement.joinArgs...)
} }
return nil return nil
} }
func (statement *Statement) writeJoin(buf *builder.BytesWriter, join join) error {
// write join operator
if _, err := fmt.Fprintf(buf, " %v JOIN", join.op); err != nil {
return err
}
// write join table or subquery
switch tp := join.table.(type) {
case builder.Builder:
if _, err := fmt.Fprintf(buf, " ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(buf)); err != nil {
return err
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
if _, err := fmt.Fprintf(buf, ") %s", statement.quote(aliasName)); err != nil {
return err
}
case *builder.Builder:
if _, err := fmt.Fprintf(buf, " ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(buf)); err != nil {
return err
}
fields := strings.Split(tp.TableName(), ".")
aliasName := statement.dialect.Quoter().Trim(fields[len(fields)-1])
aliasName = schemas.CommonQuoter.Trim(aliasName)
if _, err := fmt.Fprintf(buf, ") %s", statement.quote(aliasName)); err != nil {
return err
}
default:
tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), join.table, true)
if !utils.IsSubQuery(tbName) {
var sb strings.Builder
if err := statement.dialect.Quoter().QuoteTo(&sb, tbName); err != nil {
return err
}
tbName = sb.String()
} else {
tbName = statement.ReplaceQuote(tbName)
}
if _, err := fmt.Fprint(buf, " ", tbName); err != nil {
return err
}
}
// write on condition
if _, err := fmt.Fprint(buf, " ON "); err != nil {
return err
}
switch condTp := join.condition.(type) {
case string:
if _, err := fmt.Fprint(buf, statement.ReplaceQuote(condTp)); err != nil {
return err
}
case builder.Cond:
if err := condTp.WriteTo(statement.QuoteReplacer(buf)); err != nil {
return err
}
default:
return fmt.Errorf("unsupported join condition type: %v", condTp)
}
buf.Append(join.args...)
return nil
}

View File

@ -7,6 +7,7 @@ package statements
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"reflect" "reflect"
"strings" "strings"
@ -29,37 +30,15 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
if err := statement.ProcessIDParam(); err != nil { if err := statement.ProcessIDParam(); err != nil {
return "", nil, err return "", nil, err
} }
return statement.genSelectSQL(columnStr, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }
// GenSumSQL generates sum SQL // GenSumSQL generates sum SQL
@ -81,13 +60,16 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
} }
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
} }
sumSelect := strings.Join(sumStrs, ", ")
if err := statement.MergeConds(bean); err != nil { if err := statement.MergeConds(bean); err != nil {
return "", nil, err return "", nil, err
} }
return statement.genSelectSQL(sumSelect, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }
// GenGetSQL generates Get SQL // GenGetSQL generates Get SQL
@ -108,7 +90,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
columnStr = statement.SelectStr columnStr = statement.SelectStr
} else { } else {
// TODO: always generate column names, not use * even if join // TODO: always generate column names, not use * even if join
if len(statement.JoinStr) == 0 { if len(statement.joins) == 0 {
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.quoteColumnStr(statement.GroupByStr) columnStr = statement.quoteColumnStr(statement.GroupByStr)
@ -139,7 +121,11 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
return statement.genSelectSQL(columnStr, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, columnStr, true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }
// GenCountSQL generates the SQL for counting // GenCountSQL generates the SQL for counting
@ -148,8 +134,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
} }
var condArgs []interface{}
var err error
if len(beans) > 0 { if len(beans) > 0 {
if err := statement.SetRefBean(beans[0]); err != nil { if err := statement.SetRefBean(beans[0]); err != nil {
return "", nil, err return "", nil, err
@ -176,19 +160,27 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
subQuerySelect = selectSQL subQuerySelect = selectSQL
} }
sqlStr, condArgs, err := statement.genSelectSQL(subQuerySelect, false, false) buf := builder.NewWriter()
if err != nil { if statement.GroupByStr != "" {
if _, err := fmt.Fprintf(buf, "SELECT %s FROM (", selectSQL); err != nil {
return "", nil, err
}
}
if err := statement.writeSelect(buf, subQuerySelect, false); err != nil {
return "", nil, err return "", nil, err
} }
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) if _, err := fmt.Fprintf(buf, ") sub"); err != nil {
return "", nil, err
}
} }
return sqlStr, condArgs, nil return buf.String(), buf.Args(), nil
} }
func (statement *Statement) writeFrom(w builder.Writer) error { func (statement *Statement) writeFrom(w *builder.BytesWriter) error {
if _, err := fmt.Fprint(w, " FROM "); err != nil { if _, err := fmt.Fprint(w, " FROM "); err != nil {
return err return err
} }
@ -198,7 +190,7 @@ func (statement *Statement) writeFrom(w builder.Writer) error {
if err := statement.writeAlias(w); err != nil { if err := statement.writeAlias(w); err != nil {
return err return err
} }
return statement.writeJoin(w) return statement.writeJoins(w)
} }
func (statement *Statement) writeLimitOffset(w builder.Writer) error { func (statement *Statement) writeLimitOffset(w builder.Writer) error {
@ -218,37 +210,73 @@ func (statement *Statement) writeLimitOffset(w builder.Writer) error {
return nil return nil
} }
func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { func (statement *Statement) writeTop(w builder.Writer) error {
var ( if statement.dialect.URI().DBType != schemas.MSSQL {
distinct string return nil
dialect = statement.dialect }
top, whereStr string if statement.LimitN == nil {
mssqlCondi = builder.NewWriter() return nil
) }
_, err := fmt.Fprintf(w, " TOP %d", *statement.LimitN)
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { return err
distinct = "DISTINCT "
} }
condWriter := builder.NewWriter() func (statement *Statement) writeDistinct(w builder.Writer) error {
if err := statement.cond.WriteTo(statement.QuoteReplacer(condWriter)); err != nil { if statement.IsDistinct && !strings.HasPrefix(statement.SelectStr, "count(") {
return "", nil, err _, err := fmt.Fprint(w, " DISTINCT")
return err
}
return nil
} }
if condWriter.Len() > 0 { func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr string) error {
whereStr = " WHERE " if _, err := fmt.Fprintf(w, "SELECT "); err != nil {
return err
}
if err := statement.writeDistinct(w); err != nil {
return err
}
if err := statement.writeTop(w); err != nil {
return err
}
_, err := fmt.Fprint(w, " ", columnStr)
return err
} }
pLimitN := statement.LimitN func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
if dialect.URI().DBType == schemas.MSSQL { if !statement.cond.IsValid() {
if pLimitN != nil { return statement.writeMssqlPaginationCond(w)
LimitNValue := *pLimitN
top = fmt.Sprintf("TOP %d ", LimitNValue)
} }
if statement.Start > 0 { if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
return statement.writeMssqlPaginationCond(w)
}
func (statement *Statement) writeForUpdate(w io.Writer) error {
if !statement.IsForUpdate {
return nil
}
if statement.dialect.URI().DBType != schemas.MYSQL {
return errors.New("only support mysql for update")
}
_, err := fmt.Fprint(w, " FOR UPDATE")
return err
}
func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) error {
if statement.dialect.URI().DBType != schemas.MSSQL || statement.Start <= 0 {
return nil
}
if statement.RefTable == nil { if statement.RefTable == nil {
return "", nil, errors.New("Unsupported query limit without reference table") return errors.New("unsupported query limit without reference table")
} }
var column string var column string
if len(statement.RefTable.PKColumns()) == 0 { if len(statement.RefTable.PKColumns()) == 0 {
for _, index := range statement.RefTable.Indexes { for _, index := range statement.RefTable.Indexes {
@ -263,7 +291,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} else { } else {
column = statement.RefTable.PKColumns()[0].Name column = statement.RefTable.PKColumns()[0].Name
} }
if statement.needTableName() { if statement.NeedTableName() {
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
column = fmt.Sprintf("%s.%s", statement.TableAlias, column) column = fmt.Sprintf("%s.%s", statement.TableAlias, column)
} else { } else {
@ -271,100 +299,94 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", subWriter := builder.NewWriter()
if _, err := fmt.Fprintf(subWriter, "(%s NOT IN (SELECT TOP %d %s",
column, statement.Start, column); err != nil { column, statement.Start, column); err != nil {
return "", nil, err return err
} }
if err := statement.writeFrom(mssqlCondi); err != nil { if err := statement.writeFrom(subWriter); err != nil {
return "", nil, err return err
} }
if whereStr != "" { if statement.cond.IsValid() {
if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil { if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil {
return "", nil, err return err
} }
if err := utils.WriteBuilder(mssqlCondi, statement.QuoteReplacer(condWriter)); err != nil { if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil {
return "", nil, err return err
} }
} }
if needOrderBy { if err := statement.WriteOrderBy(subWriter); err != nil {
if err := statement.WriteOrderBy(mssqlCondi); err != nil { return err
return "", nil, err
}
}
if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err
}
if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil {
return "", nil, err
} }
if err := statement.writeGroupBy(subWriter); err != nil {
return err
} }
if _, err := fmt.Fprint(subWriter, "))"); err != nil {
return err
} }
buf := builder.NewWriter() if statement.cond.IsValid() {
if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil { if _, err := fmt.Fprint(w, " AND "); err != nil {
return "", nil, err return err
}
if err := statement.writeFrom(buf); err != nil {
return "", nil, err
}
if whereStr != "" {
if _, err := fmt.Fprint(buf, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(buf, statement.QuoteReplacer(condWriter)); err != nil {
return "", nil, err
}
}
if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 {
if _, err := fmt.Fprint(buf, " AND "); err != nil {
return "", nil, err
} }
} else { } else {
if _, err := fmt.Fprint(buf, " WHERE "); err != nil { if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return "", nil, err return err
} }
} }
if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { return utils.WriteBuilder(w, subWriter)
return "", nil, err
}
} }
if err := statement.WriteGroupBy(buf); err != nil { func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr string) error {
return "", nil, err if statement.LimitN == nil {
return nil
} }
if err := statement.writeHaving(buf); err != nil {
return "", nil, err oldString := w.String()
} w.Reset()
if needOrderBy {
if err := statement.WriteOrderBy(buf); err != nil {
return "", nil, err
}
}
if needLimit {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if err := statement.writeLimitOffset(buf); err != nil {
return "", nil, err
}
} else if dialect.URI().DBType == schemas.ORACLE {
if pLimitN != nil {
oldString := buf.String()
buf.Reset()
rawColStr := columnStr rawColStr := columnStr
if rawColStr == "*" { if rawColStr == "*" {
rawColStr = "at.*" rawColStr = "at.*"
} }
fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", _, err := fmt.Fprintf(w, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start) columnStr, rawColStr, oldString, statement.Start+*statement.LimitN, statement.Start)
} return err
}
}
if statement.IsForUpdate {
return dialect.ForUpdateSQL(buf.String()), buf.Args(), nil
} }
return buf.String(), buf.Args(), nil func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error {
if err := statement.writeSelectColumns(buf, columnStr); err != nil {
return err
}
if err := statement.writeFrom(buf); err != nil {
return err
}
if err := statement.writeWhere(buf); err != nil {
return err
}
if err := statement.writeGroupBy(buf); err != nil {
return err
}
if err := statement.writeHaving(buf); err != nil {
return err
}
if err := statement.WriteOrderBy(buf); err != nil {
return err
}
dialect := statement.dialect
if needLimit {
if dialect.URI().DBType == schemas.ORACLE {
if err := statement.writeOracleLimit(buf, columnStr); err != nil {
return err
}
} else if dialect.URI().DBType != schemas.MSSQL {
if err := statement.writeLimitOffset(buf); err != nil {
return err
}
}
}
return statement.writeForUpdate(buf)
} }
// GenExistSQL generates Exist SQL // GenExistSQL generates Exist SQL
@ -402,7 +424,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT TOP 1 * FROM %s", tableName); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.writeJoin(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
@ -417,7 +439,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.writeJoin(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
@ -438,7 +460,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT 1 FROM %s", tableName); err != nil {
return "", nil, err return "", nil, err
} }
if err := statement.writeJoin(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if statement.Conds().IsValid() {
@ -457,6 +479,33 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
} }
func (statement *Statement) genSelectColumnStr() string {
// manually select columns
if len(statement.SelectStr) > 0 {
return statement.SelectStr
}
columnStr := statement.ColumnStr()
if columnStr != "" {
return columnStr
}
// autodetect columns
if statement.GroupByStr != "" {
return statement.quoteColumnStr(statement.GroupByStr)
}
if len(statement.joins) != 0 {
return "*"
}
columnStr = statement.genColumnStr()
if columnStr == "" {
columnStr = "*"
}
return columnStr
}
// GenFindSQL generates Find SQL // GenFindSQL generates Find SQL
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
if statement.RawSQL != "" { if statement.RawSQL != "" {
@ -467,33 +516,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
return "", nil, ErrTableNotFound return "", nil, ErrTableNotFound
} }
columnStr := statement.ColumnStr()
if len(statement.SelectStr) > 0 {
columnStr = statement.SelectStr
} else {
if statement.JoinStr == "" {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = statement.genColumnStr()
}
}
} else {
if columnStr == "" {
if statement.GroupByStr != "" {
columnStr = statement.quoteColumnStr(statement.GroupByStr)
} else {
columnStr = "*"
}
}
}
if columnStr == "" {
columnStr = "*"
}
}
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
return statement.genSelectSQL(columnStr, true, true) buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
return "", nil, err
}
return buf.String(), buf.Args(), nil
} }

View File

@ -102,7 +102,7 @@ func (statement *Statement) genColumnStr() string {
buf.WriteString(", ") buf.WriteString(", ")
} }
if statement.JoinStr != "" { if len(statement.joins) > 0 {
if statement.TableAlias != "" { if statement.TableAlias != "" {
buf.WriteString(statement.TableAlias) buf.WriteString(statement.TableAlias)
} else { } else {
@ -119,7 +119,7 @@ func (statement *Statement) genColumnStr() string {
} }
func (statement *Statement) colName(col *schemas.Column, tableName string) string { func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() { if statement.NeedTableName() {
nm := tableName nm := tableName
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
nm = statement.TableAlias nm = statement.TableAlias

View File

@ -34,6 +34,13 @@ var (
ErrTableNotFound = errors.New("Table not found") ErrTableNotFound = errors.New("Table not found")
) )
type join struct {
op string
table interface{}
condition interface{}
args []interface{}
}
// Statement save all the sql info for executing SQL // Statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *schemas.Table RefTable *schemas.Table
@ -45,8 +52,7 @@ type Statement struct {
idParam schemas.PK idParam schemas.PK
orderStr string orderStr string
orderArgs []interface{} orderArgs []interface{}
JoinStr string joins []join
joinArgs []interface{}
GroupByStr string GroupByStr string
HavingStr string HavingStr string
SelectStr string SelectStr string
@ -123,8 +129,7 @@ func (statement *Statement) Reset() {
statement.LimitN = nil statement.LimitN = nil
statement.ResetOrderBy() statement.ResetOrderBy()
statement.UseCascade = true statement.UseCascade = true
statement.JoinStr = "" statement.joins = nil
statement.joinArgs = make([]interface{}, 0)
statement.GroupByStr = "" statement.GroupByStr = ""
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnMap = columnMap{} statement.ColumnMap = columnMap{}
@ -205,8 +210,8 @@ func (statement *Statement) SetRefBean(bean interface{}) error {
return nil return nil
} }
func (statement *Statement) needTableName() bool { func (statement *Statement) NeedTableName() bool {
return len(statement.JoinStr) > 0 return len(statement.joins) > 0
} }
// Incr Generate "Update ... Set column = column + arg" statement // Incr Generate "Update ... Set column = column + arg" statement
@ -290,7 +295,7 @@ func (statement *Statement) GroupBy(keys string) *Statement {
return statement return statement
} }
func (statement *Statement) WriteGroupBy(w builder.Writer) error { func (statement *Statement) writeGroupBy(w builder.Writer) error {
if statement.GroupByStr == "" { if statement.GroupByStr == "" {
return nil return nil
} }
@ -605,7 +610,7 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i
// MergeConds merge conditions from bean and id // MergeConds merge conditions from bean and id
func (statement *Statement) MergeConds(bean interface{}) error { func (statement *Statement) MergeConds(bean interface{}) error {
if !statement.NoAutoCondition && statement.RefTable != nil { if !statement.NoAutoCondition && statement.RefTable != nil {
addedTableName := (len(statement.JoinStr) > 0) addedTableName := (len(statement.joins) > 0)
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil { if err != nil {
return err return err
@ -673,7 +678,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
// CondDeleted returns the conditions whether a record is soft deleted. // CondDeleted returns the conditions whether a record is soft deleted.
func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
colName := statement.quote(col.Name) colName := statement.quote(col.Name)
if statement.JoinStr != "" { if len(statement.joins) > 0 {
var prefix string var prefix string
if statement.TableAlias != "" { if statement.TableAlias != "" {
prefix = statement.TableAlias prefix = statement.TableAlias

View File

@ -19,7 +19,8 @@ import (
) )
func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil, func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) (bool, error) { includeAutoIncr, update bool,
) (bool, error) {
columnMap := statement.ColumnMap columnMap := statement.ColumnMap
omitColumnMap := statement.OmitColumnMap omitColumnMap := statement.OmitColumnMap
unscoped := statement.unscoped unscoped := statement.unscoped
@ -64,15 +65,16 @@ func (statement *Statement) ifAddColUpdate(col *schemas.Column, includeVersion,
// BuildUpdates auto generating update columnes and values according a struct // BuildUpdates auto generating update columnes and values according a struct
func (statement *Statement) BuildUpdates(tableValue reflect.Value, func (statement *Statement) BuildUpdates(tableValue reflect.Value,
includeVersion, includeUpdated, includeNil, includeVersion, includeUpdated, includeNil,
includeAutoIncr, update bool) ([]string, []interface{}, error) { includeAutoIncr, update bool,
) ([]string, []interface{}, error) {
table := statement.RefTable table := statement.RefTable
allUseBool := statement.allUseBool allUseBool := statement.allUseBool
useAllCols := statement.useAllCols useAllCols := statement.useAllCols
mustColumnMap := statement.MustColumnMap mustColumnMap := statement.MustColumnMap
nullableMap := statement.NullableMap nullableMap := statement.NullableMap
var colNames = make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) args := make([]interface{}, 0)
for _, col := range table.Columns() { for _, col := range table.Columns() {
ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil, ok, err := statement.ifAddColUpdate(col, includeVersion, includeUpdated, includeNil,

10
rows.go
View File

@ -46,8 +46,8 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
if rows.session.statement.RawSQL == "" { if rows.session.statement.RawSQL == "" {
var autoCond builder.Cond var autoCond builder.Cond
var addedTableName = (len(session.statement.JoinStr) > 0) addedTableName := session.statement.NeedTableName()
var table = rows.session.statement.RefTable table := rows.session.statement.RefTable
if !session.statement.NoAutoCondition { if !session.statement.NoAutoCondition {
var err error var err error
@ -103,12 +103,12 @@ func (rows *Rows) Scan(beans ...interface{}) error {
return rows.Err() return rows.Err()
} }
var bean = beans[0] bean := beans[0]
var tp = reflect.TypeOf(bean) tp := reflect.TypeOf(bean)
if tp.Kind() == reflect.Ptr { if tp.Kind() == reflect.Ptr {
tp = tp.Elem() tp = tp.Elem()
} }
var beanKind = tp.Kind() beanKind := tp.Kind()
if len(beans) == 1 { if len(beans) == 1 {
if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType {

10
schemas/collation.go Normal file
View File

@ -0,0 +1,10 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package schemas
type Collation struct {
Name string
Column string // blank means it's a table collation
}

View File

@ -45,6 +45,7 @@ type Column struct {
DisableTimeZone bool DisableTimeZone bool
TimeZone *time.Location // column specified time zone TimeZone *time.Location // column specified time zone
Comment string Comment string
Collation string
} }
// NewColumn creates a new column // NewColumn creates a new column

View File

@ -27,6 +27,7 @@ type Table struct {
StoreEngine string StoreEngine string
Charset string Charset string
Comment string Comment string
Collation string
} }
// NewEmptyTable creates an empty table // NewEmptyTable creates an empty table
@ -36,7 +37,8 @@ func NewEmptyTable() *Table {
// NewTable creates a new Table object // NewTable creates a new Table object
func NewTable(name string, t reflect.Type) *Table { func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t, return &Table{
Name: name, Type: t,
columnsSeq: make([]string, 0), columnsSeq: make([]string, 0),
columns: make([]*Column, 0), columns: make([]*Column, 0),
columnsMap: make(map[string][]*Column), columnsMap: make(map[string][]*Column),

View File

@ -354,7 +354,7 @@ func (session *Session) DB() *core.DB {
func (session *Session) canCache() bool { func (session *Session) canCache() bool {
if session.statement.RefTable == nil || if session.statement.RefTable == nil ||
session.statement.JoinStr != "" || session.statement.NeedTableName() ||
session.statement.RawSQL != "" || session.statement.RawSQL != "" ||
!session.statement.UseCache || !session.statement.UseCache ||
session.statement.IsForUpdate || session.statement.IsForUpdate ||

View File

@ -114,7 +114,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
var ( var (
table = session.statement.RefTable table = session.statement.RefTable
addedTableName = (len(session.statement.JoinStr) > 0) addedTableName = session.statement.NeedTableName()
autoCond builder.Cond autoCond builder.Cond
) )
if tp == tpStruct { if tp == tpStruct {

View File

@ -15,7 +15,6 @@ import (
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
) )
// Ping test if database is ok // Ping test if database is ok
@ -169,7 +168,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
return nil return nil
} }
var seqName = utils.SeqName(tableName) seqName := utils.SeqName(tableName)
exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName) exist, err := session.engine.dialect.IsSequenceExist(session.ctx, session.getQueryer(), seqName)
if err != nil { if err != nil {
return err return err
@ -244,228 +243,6 @@ func (session *Session) addUnique(tableName, uqeName string) error {
return err return err
} }
// Sync2 synchronize structs to database tables
// Depricated
func (session *Session) Sync2(beans ...interface{}) error {
return session.Sync(beans...)
}
// Sync synchronize structs to database tables
func (session *Session) Sync(beans ...interface{}) error {
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil {
return err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
var tbName string
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
break
}
}
// this is a new table
if oriTable == nil {
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return err
}
err = session.createUniques(bean)
if err != nil {
return err
}
err = session.createIndexes(bean)
if err != nil {
return err
}
continue
}
// this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil {
return err
}
// check columns
for _, col := range table.Columns() {
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
break
}
}
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil {
return err
}
continue
}
err = nil
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType)
}
}
}
} else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else if col.Comment != oriCol.Comment {
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
if col.Default != oriCol.Default {
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default)
}
}
if col.Nullable != oriCol.Nullable {
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
if err != nil {
return err
}
}
var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
foundIndexNames[name2] = true
break
}
}
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
}
oriIndex = nil
}
}
if oriIndex == nil {
addedNames[name] = index
}
}
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
}
}
}
for name, index := range addedNames {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return err
}
}
// check all the columns which removed from struct fields but left on database tables.
for _, colName := range oriTable.ColumnsSeq() {
if table.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
}
}
}
return nil
}
// ImportFile SQL DDL file // ImportFile SQL DDL file
func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) {
file, err := os.Open(ddlPath) file, err := os.Open(ddlPath)
@ -490,7 +267,7 @@ func (session *Session) Import(r io.Reader) ([]sql.Result, error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
var oriInSingleQuote = inSingleQuote oriInSingleQuote := inSingleQuote
for i, b := range data { for i, b := range data {
if startComment { if startComment {
if b == '\n' { if b == '\n' {

276
sync.go Normal file
View File

@ -0,0 +1,276 @@
// Copyright 2023 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"strings"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
type SyncOptions struct {
WarnIfDatabaseColumnMissed bool
}
type SyncResult struct{}
// Sync the new struct changes to database, this method will automatically add
// table, column, index, unique. but will not delete or change anything.
// If you change some field, you should change the database manually.
func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.Sync(beans...)
}
// SyncWithOptions sync the database schemas according options and table structs
func (engine *Engine) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) {
session := engine.NewSession()
defer session.Close()
return session.SyncWithOptions(opts, beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (engine *Engine) Sync2(beans ...interface{}) error {
return engine.Sync(beans...)
}
// Sync2 synchronize structs to database tables
// Depricated
func (session *Session) Sync2(beans ...interface{}) error {
return session.Sync(beans...)
}
// Sync synchronize structs to database tables
func (session *Session) Sync(beans ...interface{}) error {
_, err := session.SyncWithOptions(SyncOptions{
WarnIfDatabaseColumnMissed: false,
}, beans...)
return err
}
func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) {
engine := session.engine
if session.isAutoClose {
session.isAutoClose = false
defer session.Close()
}
tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx)
if err != nil {
return nil, err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var syncResult SyncResult
for _, bean := range beans {
v := utils.ReflectValue(bean)
table, err := engine.tagParser.ParseWithCache(v)
if err != nil {
return nil, err
}
var tbName string
if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
}
tbNameWithSchema := engine.tbNameWithSchema(tbName)
var oriTable *schemas.Table
for _, tb := range tables {
if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
oriTable = tb
break
}
}
// this is a new table
if oriTable == nil {
err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
if err != nil {
return nil, err
}
err = session.createUniques(bean)
if err != nil {
return nil, err
}
err = session.createIndexes(bean)
if err != nil {
return nil, err
}
continue
}
// this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil {
return nil, err
}
// check columns
for _, col := range table.Columns() {
var oriCol *schemas.Column
for _, col2 := range oriTable.Columns() {
if strings.EqualFold(col.Name, col2.Name) {
oriCol = col2
break
}
}
// column is not exist on table
if oriCol == nil {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
if err = session.addColumn(col.Name); err != nil {
return nil, err
}
continue
}
err = nil
expectedType := engine.dialect.SQLType(col)
curType := engine.dialect.SQLType(oriCol)
if expectedType != curType {
if expectedType == schemas.Text &&
strings.HasPrefix(curType, schemas.Varchar) {
// currently only support mysql & postgres
if engine.dialect.URI().DBType == schemas.MYSQL ||
engine.dialect.URI().DBType == schemas.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbNameWithSchema, col.Name, curType, expectedType)
}
}
}
} else if expectedType == schemas.Varchar {
if engine.dialect.URI().DBType == schemas.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
} else if col.Comment != oriCol.Comment {
if engine.dialect.URI().DBType == schemas.POSTGRES ||
engine.dialect.URI().DBType == schemas.MYSQL {
_, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col))
}
}
if col.Default != oriCol.Default {
switch {
case col.IsAutoIncrement: // For autoincrement column, don't check default
case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) &&
((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
(strings.EqualFold(col.Default, "false") && oriCol.Default == "0")):
default:
engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
tbName, col.Name, oriCol.Default, col.Default)
}
}
if col.Nullable != oriCol.Nullable {
engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
tbName, col.Name, oriCol.Nullable, col.Nullable)
}
if err != nil {
return nil, err
}
}
foundIndexNames := make(map[string]bool)
addedNames := make(map[string]*schemas.Index)
for name, index := range table.Indexes {
var oriIndex *schemas.Index
for name2, index2 := range oriTable.Indexes {
if index.Equal(index2) {
oriIndex = index2
foundIndexNames[name2] = true
break
}
}
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return nil, err
}
oriIndex = nil
}
}
if oriIndex == nil {
addedNames[name] = index
}
}
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return nil, err
}
}
}
for name, index := range addedNames {
if index.Type == schemas.UniqueType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == schemas.IndexType {
session.statement.RefTable = table
session.statement.SetTableName(tbNameWithSchema)
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return nil, err
}
}
if opts.WarnIfDatabaseColumnMissed {
// check all the columns which removed from struct fields but left on database tables.
for _, colName := range oriTable.ColumnsSeq() {
if table.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
}
}
}
}
return &syncResult, nil
}

View File

@ -31,6 +31,12 @@ type TableIndices interface {
var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem()
type TableCollations interface {
TableCollations() []*schemas.Collation
}
var tpTableCollations = reflect.TypeOf((*TableCollations)(nil)).Elem()
// Parser represents a parser for xorm tag // Parser represents a parser for xorm tag
type Parser struct { type Parser struct {
identifier string identifier string
@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
} }
} }
collations := tableCollations(v)
for _, collation := range collations {
if collation.Name == "" {
continue
}
if collation.Column == "" {
table.Collation = collation.Name
} else {
col := table.GetColumn(collation.Column)
if col == nil {
return nil, ErrUnsupportedType
}
col.Collation = collation.Name // this may override definition in struct tag
}
}
return table, nil return table, nil
} }
@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index {
} }
return nil return nil
} }
func tableCollations(v reflect.Value) []*schemas.Collation {
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
if v.Kind() == reflect.Ptr {
v = v.Elem()
if v.Type().Implements(tpTableCollations) {
return v.Interface().(TableCollations).TableCollations()
}
} else if v.CanAddr() {
v1 := v.Addr()
if v1.Type().Implements(tpTableCollations) {
return v1.Interface().(TableCollations).TableCollations()
}
}
return nil
}

View File

@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{
"COMMENT": CommentTagHandler, "COMMENT": CommentTagHandler,
"EXTENDS": ExtendsTagHandler, "EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler, "UNSIGNED": UnsignedTagHandler,
"COLLATE": CollateTagHandler,
} }
func init() { func init() {
@ -282,6 +283,16 @@ func CommentTagHandler(ctx *Context) error {
return nil return nil
} }
func CollateTagHandler(ctx *Context) error {
if len(ctx.params) > 0 {
ctx.col.Collation = ctx.params[0]
} else {
ctx.col.Collation = ctx.nextTag
ctx.ignoreNext = true
}
return nil
}
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}

View File

@ -11,11 +11,12 @@ import (
) )
func TestSplitTag(t *testing.T) { func TestSplitTag(t *testing.T) {
var cases = []struct { cases := []struct {
tag string tag string
tags []tag tags []tag
}{ }{
{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ {
"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
{ {
name: "not", name: "not",
}, },
@ -33,13 +34,15 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"TEXT", []tag{ {
"TEXT", []tag{
{ {
name: "TEXT", name: "TEXT",
}, },
}, },
}, },
{"default('2000-01-01 00:00:00')", []tag{ {
"default('2000-01-01 00:00:00')", []tag{
{ {
name: "default", name: "default",
params: []string{ params: []string{
@ -48,7 +51,8 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"json binary", []tag{ {
"json binary", []tag{
{ {
name: "json", name: "json",
}, },
@ -57,14 +61,16 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{"numeric(10, 2)", []tag{ {
"numeric(10, 2)", []tag{
{ {
name: "numeric", name: "numeric",
params: []string{"10", "2"}, params: []string{"10", "2"},
}, },
}, },
}, },
{"numeric(10, 2) notnull", []tag{ {
"numeric(10, 2) notnull", []tag{
{ {
name: "numeric", name: "numeric",
params: []string{"10", "2"}, params: []string{"10", "2"},
@ -74,6 +80,16 @@ func TestSplitTag(t *testing.T) {
}, },
}, },
}, },
{
"collate utf8mb4_bin", []tag{
{
name: "collate",
},
{
name: "utf8mb4_bin",
},
},
},
} }
for _, kase := range cases { for _, kase := range cases {
@ -82,7 +98,7 @@ func TestSplitTag(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, len(tags), len(kase.tags)) assert.EqualValues(t, len(tags), len(kase.tags))
for i := 0; i < len(tags); i++ { for i := 0; i < len(tags); i++ {
assert.Equal(t, tags[i], kase.tags[i]) assert.Equal(t, kase.tags[i], tags[i])
} }
}) })
} }