diff --git a/convert/interface.go b/convert/interface.go index b0f28c81..2cc8d9f4 100644 --- a/convert/interface.go +++ b/convert/interface.go @@ -24,7 +24,10 @@ func Interface2Interface(userLocation *time.Location, v interface{}) (interface{ return vv.String, nil case *sql.RawBytes: if len([]byte(*vv)) > 0 { - return []byte(*vv), nil + src := []byte(*vv) + dest := make([]byte, len(src)) + copy(dest, src) + return dest, nil } return nil, nil case *sql.NullInt32: diff --git a/dialects/dameng.go b/dialects/dameng.go index 5e92ec2f..23d1836a 100644 --- a/dialects/dameng.go +++ b/dialects/dameng.go @@ -659,7 +659,7 @@ func (db *dameng) DropTableSQL(tableName string) (string, bool) { // ModifyColumnSQL returns a SQL to modify SQL func (db *dameng) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s MODIFY %s", db.quoter.Quote(tableName), s) } @@ -692,7 +692,7 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl } } - s, _ := ColumnString(db, col, false) + s, _ := ColumnString(db, col, false, false) if _, err := b.WriteString(s); err != nil { return "", false, err } @@ -709,7 +709,13 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl return "", false, err } } - if _, err := b.WriteString(fmt.Sprintf("CONSTRAINT PK_%s PRIMARY KEY (", tableName)); err != nil { + if _, err := b.WriteString("CONSTRAINT PK_"); err != nil { + return "", false, err + } + if _, err := b.WriteString(tableName); err != nil { + return "", false, err + } + if _, err := b.WriteString(" PRIMARY KEY ("); err != nil { return "", false, err } if err := quoter.JoinWrite(&b, pkList, ","); err != nil { @@ -837,7 +843,11 @@ func addSingleQuote(name string) string { if name[0] == '\'' && name[len(name)-1] == '\'' { return name } - return fmt.Sprintf("'%s'", name) + var b strings.Builder + b.WriteRune('\'') + b.WriteString(name) + b.WriteRune('\'') + return b.String() } func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { diff --git a/dialects/dialect.go b/dialects/dialect.go index 70d599e6..d1c5f200 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -135,7 +135,7 @@ func (db *Base) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, false) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { @@ -209,7 +209,7 @@ func (db *Base) IsColumnExist(queryer core.Queryer, ctx context.Context, tableNa // AddColumnSQL returns a SQL to add a column func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, true) + s, _ := ColumnString(db.dialect, col, true, false) return fmt.Sprintf("ALTER TABLE %s ADD %s", db.dialect.Quoter().Quote(tableName), s) } @@ -241,7 +241,7 @@ func (db *Base) DropIndexSQL(tableName string, index *schemas.Index) string { // ModifyColumnSQL returns a SQL to modify SQL func (db *Base) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, false) return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } @@ -254,9 +254,7 @@ func (db *Base) ForUpdateSQL(query string) string { func (db *Base) SetParams(params map[string]string) { } -var ( - dialects = map[string]func() Dialect{} -) +var dialects = map[string]func() Dialect{} // RegisterDialect register database dialect func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { @@ -307,7 +305,7 @@ func init() { } // ColumnString generate column description string according dialect -func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) (string, error) { +func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey, supportCollation bool) (string, error) { bd := strings.Builder{} if err := dialect.Quoter().QuoteTo(&bd, col.Name); err != nil { @@ -322,6 +320,15 @@ func ColumnString(dialect Dialect, col *schemas.Column, includePrimaryKey bool) return "", err } + if supportCollation && col.Collation != "" { + if _, err := bd.WriteString(" COLLATE "); err != nil { + return "", err + } + if _, err := bd.WriteString(col.Collation); err != nil { + return "", err + } + } + if includePrimaryKey && col.IsPrimaryKey { if _, err := bd.WriteString(" PRIMARY KEY"); err != nil { return "", err diff --git a/dialects/filter.go b/dialects/filter.go index ff9a44e8..add67c1b 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -6,7 +6,7 @@ package dialects import ( "context" - "fmt" + "strconv" "strings" ) @@ -29,10 +29,11 @@ func convertQuestionMark(sql, prefix string, start int) string { var isMaybeLineComment bool var isMaybeComment bool var isMaybeCommentEnd bool - var index = start + index := start for _, c := range sql { if !beginSingleQuote && !isLineComment && !isComment && c == '?' { - buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) + buf.WriteString(prefix) + buf.WriteString(strconv.Itoa(index)) index++ } else { if isMaybeLineComment { diff --git a/dialects/mssql.go b/dialects/mssql.go index 1b6fe692..e517e688 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -428,7 +428,7 @@ func (db *mssql) DropTableSQL(tableName string) (string, bool) { } func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, false) + s, _ := ColumnString(db.dialect, col, false, true) return fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s", db.quoter.Quote(tableName), s) } @@ -454,7 +454,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, - ISNULL(p.is_primary_key, 0), a.is_identity as is_identity + ISNULL(p.is_primary_key, 0), a.is_identity as is_identity, a.collation_name from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id left join sys.syscomments c on a.default_object_id=c.id @@ -475,9 +475,10 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName colSeq := make([]string, 0) for rows.Next() { var name, ctype, vdefault string + var collation *string var maxLen, precision, scale int64 var nullable, isPK, defaultIsNull, isIncrement bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement, &collation) if err != nil { return nil, nil, err } @@ -499,6 +500,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } else { col.Length = maxLen } + if collation != nil { + col.Collation = *collation + } switch ct { case "DATETIMEOFFSET": col.SQLType = schemas.SQLType{Name: schemas.TimeStampz, DefaultLength: 0, DefaultLength2: 0} @@ -646,7 +650,7 @@ func (db *mssql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) if i != len(table.ColumnsSeq())-1 { diff --git a/dialects/mysql.go b/dialects/mysql.go index 195e1f23..5663d1dd 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -380,12 +380,24 @@ func (db *mysql) IsTableExist(queryer core.Queryer, ctx context.Context, tableNa func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { quoter := db.dialect.Quoter() - s, _ := ColumnString(db, col, true) - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), s) + s, _ := ColumnString(db, col, true, true) + var b strings.Builder + b.WriteString("ALTER TABLE ") + quoter.QuoteTo(&b, tableName) + b.WriteString(" ADD ") + b.WriteString(s) if len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" + b.WriteString(" COMMENT '") + b.WriteString(col.Comment) + b.WriteString("'") } - return sql + return b.String() +} + +// ModifyColumnSQL returns a SQL to modify SQL +func (db *mysql) ModifyColumnSQL(tableName string, col *schemas.Column) string { + s, _ := ColumnString(db.dialect, col, false, true) + return fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s", db.quoter.Quote(tableName), s) } func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -398,7 +410,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + - alreadyQuoted + " AS NEEDS_QUOTE " + + alreadyQuoted + " AS NEEDS_QUOTE, `COLLATION_NAME` " + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" @@ -416,8 +428,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName var columnName, nullableStr, colType, colKey, extra, comment string var alreadyQuoted, isUnsigned bool - var colDefault, maxLength *string - err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted) + var colDefault, maxLength, collation *string + err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted, &collation) if err != nil { return nil, nil, err } @@ -433,6 +445,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } else { col.DefaultIsEmpty = true } + if collation != nil { + col.Collation = *collation + } fields := strings.Fields(colType) if len(fields) == 2 && fields[1] == "unsigned" { @@ -525,7 +540,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{db.uri.DBName} - s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + + s := "SELECT `TABLE_NAME`, `ENGINE`, `AUTO_INCREMENT`, `TABLE_COMMENT`, `TABLE_COLLATION` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" rows, err := queryer.QueryContext(ctx, s, args...) @@ -537,9 +552,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { table := schemas.NewEmptyTable() - var name, engine string + var name, engine, collation string var autoIncr, comment *string - err = rows.Scan(&name, &engine, &autoIncr, &comment) + err = rows.Scan(&name, &engine, &autoIncr, &comment, &collation) if err != nil { return nil, err } @@ -549,6 +564,7 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Comment = *comment } table.StoreEngine = engine + table.Collation = collation tables = append(tables, table) } if rows.Err() != nil { @@ -640,7 +656,7 @@ func (db *mysql) CreateTableSQL(ctx context.Context, queryer core.Queryer, table for i, colName := range table.ColumnsSeq() { col := table.GetColumn(colName) - s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1) + s, _ := ColumnString(db.dialect, col, col.IsPrimaryKey && len(table.PrimaryKeys) == 1, true) b.WriteString(s) if len(col.Comment) > 0 { diff --git a/dialects/oracle.go b/dialects/oracle.go index b0c5c38f..a5f8a5b2 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -628,7 +628,7 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl /*if col.IsPrimaryKey && len(pkList) == 1 { sql += col.String(b.dialect) } else {*/ - s, _ := ColumnString(db, col, false) + s, _ := ColumnString(db, col, false, false) sql += s // } sql = strings.TrimSpace(sql) diff --git a/dialects/postgres.go b/dialects/postgres.go index 5efe54f4..28196891 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -992,7 +992,7 @@ func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tabl } func (db *postgres) AddColumnSQL(tableName string, col *schemas.Column) string { - s, _ := ColumnString(db.dialect, col, true) + s, _ := ColumnString(db.dialect, col, true, false) quoter := db.dialect.Quoter() addColumnSQL := "" @@ -1078,7 +1078,7 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` +WHERE n.nspname= s.table_schema AND c.relkind = 'r' AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` schema := db.getSchema() if schema != "" { diff --git a/engine.go b/engine.go index 389819e7..aa9d8050 100644 --- a/engine.go +++ b/engine.go @@ -1120,21 +1120,6 @@ func (engine *Engine) UnMapType(t reflect.Type) { engine.tagParser.ClearCacheTable(t) } -// Sync the new struct changes to database, this method will automatically add -// table, column, index, unique. but will not delete or change anything. -// If you change some field, you should change the database manually. -func (engine *Engine) Sync(beans ...interface{}) error { - session := engine.NewSession() - defer session.Close() - return session.Sync(beans...) -} - -// Sync2 synchronize structs to database tables -// Depricated -func (engine *Engine) Sync2(beans ...interface{}) error { - return engine.Sync(beans...) -} - // CreateTables create tabls according bean func (engine *Engine) CreateTables(beans ...interface{}) error { session := engine.NewSession() diff --git a/integrations/session_schema_test.go b/integrations/schema_test.go similarity index 85% rename from integrations/session_schema_test.go rename to integrations/schema_test.go index 3212d027..149c6394 100644 --- a/integrations/session_schema_test.go +++ b/integrations/schema_test.go @@ -5,6 +5,7 @@ package integrations import ( + "errors" "fmt" "strings" "testing" @@ -325,14 +326,14 @@ func TestIsTableEmpty(t *testing.T) { type PictureEmpty struct { Id int64 - Url string `xorm:"unique"` //image's url + Url string `xorm:"unique"` // image's url Title string Description string Created time.Time `xorm:"created"` ILike int PageView int From_url string // nolint - Pre_url string `xorm:"unique"` //pre view image's url + Pre_url string `xorm:"unique"` // pre view image's url Uid int64 } @@ -458,7 +459,7 @@ func TestSync2_2(t *testing.T) { assert.NoError(t, PrepareEngine()) - var tableNames = make(map[string]bool) + tableNames := make(map[string]bool) for i := 0; i < 10; i++ { tableName := fmt.Sprintf("test_sync2_index_%d", i) tableNames[tableName] = true @@ -536,3 +537,111 @@ func TestModifyColum(t *testing.T) { _, err := testEngine.Exec(alterSQL) assert.NoError(t, err) } + +type TestCollateColumn struct { + Id int64 + UserId int64 `xorm:"unique(s)"` + Name string `xorm:"varchar(20) unique(s)"` + dbtype string `xorm:"-"` +} + +func (t TestCollateColumn) TableCollations() []*schemas.Collation { + if t.dbtype == string(schemas.MYSQL) { + return []*schemas.Collation{ + { + Name: "utf8mb4_general_ci", + Column: "name", + }, + } + } else if t.dbtype == string(schemas.MSSQL) { + return []*schemas.Collation{ + { + Name: "Latin1_General_CI_AS", + Column: "name", + }, + } + } + return nil +} + +func TestCollate(t *testing.T) { + assert.NoError(t, PrepareEngine()) + assertSync(t, &TestCollateColumn{ + dbtype: string(testEngine.Dialect().URI().DBType), + }) + + _, err := testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "test", + }) + assert.NoError(t, err) + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "Test", + }) + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + ver, err1 := testEngine.DBVersion() + assert.NoError(t, err1) + + tables, err1 := testEngine.DBMetas() + assert.NoError(t, err1) + for _, table := range tables { + if table.Name == "test_collate_column" { + col := table.GetColumn("name") + if col == nil { + assert.Error(t, errors.New("not found column")) + return + } + // tidb doesn't follow utf8mb4_general_ci + if col.Collation == "utf8mb4_general_ci" && ver.Edition != "TiDB" { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + break + } + } + } else if testEngine.Dialect().URI().DBType == schemas.MSSQL { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Since SQLITE don't support modify column SQL, currrently just ignore + if testEngine.Dialect().URI().DBType != schemas.MYSQL && testEngine.Dialect().URI().DBType != schemas.MSSQL { + return + } + + var newCollation string + if testEngine.Dialect().URI().DBType == schemas.MYSQL { + newCollation = "utf8mb4_bin" + } else if testEngine.Dialect().URI().DBType != schemas.MSSQL { + newCollation = "Latin1_General_CS_AS" + } else { + return + } + + alterSQL := testEngine.Dialect().ModifyColumnSQL("test_collate_column", &schemas.Column{ + Name: "name", + SQLType: schemas.SQLType{ + Name: "VARCHAR", + }, + Length: 20, + Nullable: true, + DefaultIsEmpty: true, + Collation: newCollation, + }) + _, err = testEngine.Exec(alterSQL) + assert.NoError(t, err) + + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "test1", + }) + assert.NoError(t, err) + _, err = testEngine.Insert(&TestCollateColumn{ + UserId: 1, + Name: "Test1", + }) + assert.NoError(t, err) +} diff --git a/integrations/session_query_test.go b/integrations/session_query_test.go index b72f7ef2..00b7d7a6 100644 --- a/integrations/session_query_test.go +++ b/integrations/session_query_test.go @@ -5,11 +5,13 @@ package integrations import ( + "bytes" "strconv" "testing" "time" "xorm.io/builder" + "xorm.io/xorm/schemas" "github.com/stretchr/testify/assert" @@ -381,3 +383,68 @@ func TestQueryStringWithLimit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 0, len(data)) } + +func TestQueryBLOBInMySQL(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + var err error + type Avatar struct { + Id int64 `xorm:"autoincr pk"` + Avatar []byte `xorm:"BLOB"` + } + + assert.NoError(t, testEngine.Sync(new(Avatar))) + testEngine.Delete(Avatar{}) + + repeatBytes := func(n int, b byte) []byte { + return bytes.Repeat([]byte{b}, n) + } + + const N = 10 + var data = []Avatar{} + for i := 0; i < N; i++ { + // allocate a []byte that is as twice big as the last one + // so that the underlying buffer will need to reallocate when querying + bs := repeatBytes(1<<(i+2), 'A'+byte(i)) + data = append(data, Avatar{ + Avatar: bs, + }) + } + _, err = testEngine.Insert(data) + assert.NoError(t, err) + defer func() { + testEngine.Delete(Avatar{}) + }() + + { + records, err := testEngine.QueryInterface("select avatar from " + testEngine.Quote(testEngine.TableName("avatar", true))) + assert.NoError(t, err) + for i, record := range records { + bs := record["avatar"].([]byte) + assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3]) + t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2]) + } + } + + { + arr := make([][]interface{}, 0) + err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr) + assert.NoError(t, err) + for i, record := range arr { + bs := record[0].([]byte) + assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3]) + t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2]) + } + } + + { + arr := make([]map[string]interface{}, 0) + err = testEngine.Table(testEngine.Quote(testEngine.TableName("avatar", true))).Cols("avatar").Find(&arr) + assert.NoError(t, err) + for i, record := range arr { + bs := record["avatar"].([]byte) + assert.EqualValues(t, repeatBytes(1<<(i+2), 'A'+byte(i))[:3], bs[:3]) + t.Logf("%d => %p => %02x %02x %02x", i, bs, bs[0], bs[1], bs[2]) + } + } +} diff --git a/internal/statements/cache.go b/internal/statements/cache.go index 669cd018..9dd76754 100644 --- a/internal/statements/cache.go +++ b/internal/statements/cache.go @@ -6,6 +6,7 @@ package statements import ( "fmt" + "strconv" "strings" "xorm.io/xorm/internal/utils" @@ -26,14 +27,19 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string { return "" } - var top string + var b strings.Builder + b.WriteString("SELECT ") pLimitN := statement.LimitN if pLimitN != nil && statement.dialect.URI().DBType == schemas.MSSQL { - top = fmt.Sprintf("TOP %d ", *pLimitN) + b.WriteString("TOP ") + b.WriteString(strconv.Itoa(*pLimitN)) + b.WriteString(" ") } + b.WriteString(colstrs) + b.WriteString(" FROM ") + b.WriteString(sqls[1]) - newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1]) - return newsql + return b.String() } return "" } @@ -54,7 +60,7 @@ func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { return "", "" } - var whereStr = sqls[1] + whereStr := sqls[1] // TODO: for postgres only, if any other database? var paraStr string diff --git a/schemas/collation.go b/schemas/collation.go new file mode 100644 index 00000000..acec5268 --- /dev/null +++ b/schemas/collation.go @@ -0,0 +1,10 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package schemas + +type Collation struct { + Name string + Column string // blank means it's a table collation +} diff --git a/schemas/column.go b/schemas/column.go index 5a579e92..08d34b91 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -45,6 +45,7 @@ type Column struct { DisableTimeZone bool TimeZone *time.Location // column specified time zone Comment string + Collation string } // NewColumn creates a new column diff --git a/schemas/table.go b/schemas/table.go index 91b33e06..5c38cc70 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -27,6 +27,7 @@ type Table struct { StoreEngine string Charset string Comment string + Collation string } // NewEmptyTable creates an empty table @@ -36,7 +37,8 @@ func NewEmptyTable() *Table { // NewTable creates a new Table object func NewTable(name string, t reflect.Type) *Table { - return &Table{Name: name, Type: t, + return &Table{ + Name: name, Type: t, columnsSeq: make([]string, 0), columns: make([]*Column, 0), columnsMap: make(map[string][]*Column), diff --git a/sync.go b/sync.go new file mode 100644 index 00000000..6583e341 --- /dev/null +++ b/sync.go @@ -0,0 +1,273 @@ +// Copyright 2023 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "strings" + + "xorm.io/xorm/internal/utils" + "xorm.io/xorm/schemas" +) + +type SyncOptions struct { + WarnIfDatabaseColumnMissed bool +} + +type SyncResult struct{} + +// Sync the new struct changes to database, this method will automatically add +// table, column, index, unique. but will not delete or change anything. +// If you change some field, you should change the database manually. +func (engine *Engine) Sync(beans ...interface{}) error { + session := engine.NewSession() + defer session.Close() + return session.Sync(beans...) +} + +// SyncWithOptions sync the database schemas according options and table structs +func (engine *Engine) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) { + session := engine.NewSession() + defer session.Close() + return session.SyncWithOptions(opts, beans...) +} + +// Sync2 synchronize structs to database tables +// Depricated +func (engine *Engine) Sync2(beans ...interface{}) error { + return engine.Sync(beans...) +} + +// Sync2 synchronize structs to database tables +// Depricated +func (session *Session) Sync2(beans ...interface{}) error { + return session.Sync(beans...) +} + +// Sync synchronize structs to database tables +func (session *Session) Sync(beans ...interface{}) error { + _, err := session.SyncWithOptions(SyncOptions{ + WarnIfDatabaseColumnMissed: false, + }, beans...) + return err +} + +func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) { + engine := session.engine + + if session.isAutoClose { + session.isAutoClose = false + defer session.Close() + } + + tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) + if err != nil { + return nil, err + } + + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + session.resetStatement() + }() + + var syncResult SyncResult + + for _, bean := range beans { + v := utils.ReflectValue(bean) + table, err := engine.tagParser.ParseWithCache(v) + if err != nil { + return nil, err + } + var tbName string + if len(session.statement.AltTableName) > 0 { + tbName = session.statement.AltTableName + } else { + tbName = engine.TableName(bean) + } + tbNameWithSchema := engine.tbNameWithSchema(tbName) + + var oriTable *schemas.Table + for _, tb := range tables { + if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { + oriTable = tb + break + } + } + + // this is a new table + if oriTable == nil { + err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) + if err != nil { + return nil, err + } + + err = session.createUniques(bean) + if err != nil { + return nil, err + } + + err = session.createIndexes(bean) + if err != nil { + return nil, err + } + continue + } + + // this will modify an old table + if err = engine.loadTableInfo(oriTable); err != nil { + return nil, err + } + + // check columns + for _, col := range table.Columns() { + var oriCol *schemas.Column + for _, col2 := range oriTable.Columns() { + if strings.EqualFold(col.Name, col2.Name) { + oriCol = col2 + break + } + } + + // column is not exist on table + if oriCol == nil { + session.statement.RefTable = table + session.statement.SetTableName(tbNameWithSchema) + if err = session.addColumn(col.Name); err != nil { + return nil, err + } + continue + } + + err = nil + expectedType := engine.dialect.SQLType(col) + curType := engine.dialect.SQLType(oriCol) + if expectedType != curType { + if expectedType == schemas.Text && + strings.HasPrefix(curType, schemas.Varchar) { + // currently only support mysql & postgres + if engine.dialect.URI().DBType == schemas.MYSQL || + engine.dialect.URI().DBType == schemas.POSTGRES { + engine.logger.Infof("Table %s column %s change type from %s to %s\n", + tbNameWithSchema, col.Name, curType, expectedType) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } else { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", + tbNameWithSchema, col.Name, curType, expectedType) + } + } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { + if engine.dialect.URI().DBType == schemas.MYSQL { + if oriCol.Length < col.Length { + engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } + } + } else { + if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { + if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", + tbNameWithSchema, col.Name, curType, expectedType) + } + } + } + } else if expectedType == schemas.Varchar { + if engine.dialect.URI().DBType == schemas.MYSQL { + if oriCol.Length < col.Length { + engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } + } + } else if col.Comment != oriCol.Comment { + _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) + } + + if col.Default != oriCol.Default { + switch { + case col.IsAutoIncrement: // For autoincrement column, don't check default + case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && + ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || + (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): + default: + engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", + tbName, col.Name, oriCol.Default, col.Default) + } + } + if col.Nullable != oriCol.Nullable { + engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v", + tbName, col.Name, oriCol.Nullable, col.Nullable) + } + + if err != nil { + return nil, err + } + } + + foundIndexNames := make(map[string]bool) + addedNames := make(map[string]*schemas.Index) + + for name, index := range table.Indexes { + var oriIndex *schemas.Index + for name2, index2 := range oriTable.Indexes { + if index.Equal(index2) { + oriIndex = index2 + foundIndexNames[name2] = true + break + } + } + + if oriIndex != nil { + if oriIndex.Type != index.Type { + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) + _, err = session.exec(sql) + if err != nil { + return nil, err + } + oriIndex = nil + } + } + + if oriIndex == nil { + addedNames[name] = index + } + } + + for name2, index2 := range oriTable.Indexes { + if _, ok := foundIndexNames[name2]; !ok { + sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) + _, err = session.exec(sql) + if err != nil { + return nil, err + } + } + } + + for name, index := range addedNames { + if index.Type == schemas.UniqueType { + session.statement.RefTable = table + session.statement.SetTableName(tbNameWithSchema) + err = session.addUnique(tbNameWithSchema, name) + } else if index.Type == schemas.IndexType { + session.statement.RefTable = table + session.statement.SetTableName(tbNameWithSchema) + err = session.addIndex(tbNameWithSchema, name) + } + if err != nil { + return nil, err + } + } + + if opts.WarnIfDatabaseColumnMissed { + // check all the columns which removed from struct fields but left on database tables. + for _, colName := range oriTable.ColumnsSeq() { + if table.GetColumn(colName) == nil { + engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName) + } + } + } + } + + return &syncResult, nil +} diff --git a/tags/parser.go b/tags/parser.go index 028f8d0b..53ef0c10 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -31,6 +31,12 @@ type TableIndices interface { var tpTableIndices = reflect.TypeOf((*TableIndices)(nil)).Elem() +type TableCollations interface { + TableCollations() []*schemas.Collation +} + +var tpTableCollations = reflect.TypeOf((*TableCollations)(nil)).Elem() + // Parser represents a parser for xorm tag type Parser struct { identifier string @@ -356,6 +362,22 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { } } + collations := tableCollations(v) + for _, collation := range collations { + if collation.Name == "" { + continue + } + if collation.Column == "" { + table.Collation = collation.Name + } else { + col := table.GetColumn(collation.Column) + if col == nil { + return nil, ErrUnsupportedType + } + col.Collation = collation.Name // this may override definition in struct tag + } + } + return table, nil } @@ -377,3 +399,22 @@ func tableIndices(v reflect.Value) []*schemas.Index { } return nil } + +func tableCollations(v reflect.Value) []*schemas.Collation { + if v.Type().Implements(tpTableCollations) { + return v.Interface().(TableCollations).TableCollations() + } + + if v.Kind() == reflect.Ptr { + v = v.Elem() + if v.Type().Implements(tpTableCollations) { + return v.Interface().(TableCollations).TableCollations() + } + } else if v.CanAddr() { + v1 := v.Addr() + if v1.Type().Implements(tpTableCollations) { + return v1.Interface().(TableCollations).TableCollations() + } + } + return nil +} diff --git a/tags/tag.go b/tags/tag.go index 41d525e1..024c9c18 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -123,6 +123,7 @@ var defaultTagHandlers = map[string]Handler{ "COMMENT": CommentTagHandler, "EXTENDS": ExtendsTagHandler, "UNSIGNED": UnsignedTagHandler, + "COLLATE": CollateTagHandler, } func init() { @@ -282,6 +283,16 @@ func CommentTagHandler(ctx *Context) error { return nil } +func CollateTagHandler(ctx *Context) error { + if len(ctx.params) > 0 { + ctx.col.Collation = ctx.params[0] + } else { + ctx.col.Collation = ctx.nextTag + ctx.ignoreNext = true + } + return nil +} + // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname} diff --git a/tags/tag_test.go b/tags/tag_test.go index 3ceeefd1..6c456f2a 100644 --- a/tags/tag_test.go +++ b/tags/tag_test.go @@ -11,68 +11,84 @@ import ( ) func TestSplitTag(t *testing.T) { - var cases = []struct { + cases := []struct { tag string tags []tag }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ - { - name: "not", - }, - { - name: "null", - }, - { - name: "default", - }, - { - name: "'2000-01-01 00:00:00'", - }, - { - name: "TIMESTAMP", - }, - }, - }, - {"TEXT", []tag{ - { - name: "TEXT", - }, - }, - }, - {"default('2000-01-01 00:00:00')", []tag{ - { - name: "default", - params: []string{ - "'2000-01-01 00:00:00'", + { + "not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{ + { + name: "not", + }, + { + name: "null", + }, + { + name: "default", + }, + { + name: "'2000-01-01 00:00:00'", + }, + { + name: "TIMESTAMP", }, }, }, - }, - {"json binary", []tag{ - { - name: "json", - }, - { - name: "binary", + { + "TEXT", []tag{ + { + name: "TEXT", + }, }, }, - }, - {"numeric(10, 2)", []tag{ - { - name: "numeric", - params: []string{"10", "2"}, + { + "default('2000-01-01 00:00:00')", []tag{ + { + name: "default", + params: []string{ + "'2000-01-01 00:00:00'", + }, + }, }, }, - }, - {"numeric(10, 2) notnull", []tag{ - { - name: "numeric", - params: []string{"10", "2"}, - }, - { - name: "notnull", + { + "json binary", []tag{ + { + name: "json", + }, + { + name: "binary", + }, }, }, + { + "numeric(10, 2)", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + }, + }, + { + "numeric(10, 2) notnull", []tag{ + { + name: "numeric", + params: []string{"10", "2"}, + }, + { + name: "notnull", + }, + }, + }, + { + "collate utf8mb4_bin", []tag{ + { + name: "collate", + }, + { + name: "utf8mb4_bin", + }, + }, }, } @@ -82,7 +98,7 @@ func TestSplitTag(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, len(tags), len(kase.tags)) for i := 0; i < len(tags); i++ { - assert.Equal(t, tags[i], kase.tags[i]) + assert.Equal(t, kase.tags[i], tags[i]) } }) }