diff --git a/schemas/table.go b/schemas/table.go index 5c38cc70..e30bde11 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -178,3 +178,17 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) { } return PK(pk), nil } + +type TableList []*Table + +// GetByName returns a table by name with case insensitive +// If not found, return nil +// If multiple tables with same name, return the first one +func (tables TableList) GetByName(name string) *Table { + for _, table := range tables { + if strings.EqualFold(table.Name, name) { + return table + } + } + return nil +} diff --git a/sync.go b/sync.go index b8b827da..cb2d5d37 100644 --- a/sync.go +++ b/sync.go @@ -5,6 +5,7 @@ package xorm import ( + "strconv" "strings" "xorm.io/xorm/internal/utils" @@ -62,6 +63,91 @@ func (session *Session) Sync(beans ...interface{}) error { return err } +func (session *Session) syncColumn(tbNameWithSchema string, expectedCol, oriCol *schemas.Column) error { + engine := session.engine + expectedType := engine.dialect.SQLType(expectedCol) + curType := engine.dialect.SQLType(oriCol) + var canExecuteModify bool + if expectedType != curType { + if expectedType == schemas.Text && + strings.HasPrefix(curType, schemas.Varchar) { + // currently only support mysql & postgres + if engine.dialect.URI().DBType == schemas.MYSQL || + engine.dialect.URI().DBType == schemas.POSTGRES { + engine.logger.Infof("Table %s column %s change type from %s to %s\n", + tbNameWithSchema, expectedCol.Name, curType, expectedType) + canExecuteModify = true + } else { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", + tbNameWithSchema, expectedCol.Name, curType, expectedType) + } + } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { + if engine.dialect.URI().DBType == schemas.POSTGRES || + engine.dialect.URI().DBType == schemas.MYSQL { + if oriCol.Length < expectedCol.Length { + engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + tbNameWithSchema, expectedCol.Name, oriCol.Length, expectedCol.Length) + canExecuteModify = true + } + } + } else { + if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { + if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { + engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", + tbNameWithSchema, expectedCol.Name, curType, expectedType) + } + } + } + } else { + if expectedType == schemas.Varchar { + if engine.dialect.URI().DBType == schemas.POSTGRES || + engine.dialect.URI().DBType == schemas.MYSQL { + if oriCol.Length < expectedCol.Length { + engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + tbNameWithSchema, expectedCol.Name, oriCol.Length, expectedCol.Length) + canExecuteModify = true + } + } + } + + if expectedCol.Comment != oriCol.Comment { + if engine.dialect.URI().DBType == schemas.POSTGRES || + engine.dialect.URI().DBType == schemas.GBASE8S || + engine.dialect.URI().DBType == schemas.MYSQL { + canExecuteModify = true + } + } + + if expectedCol.Nullable != oriCol.Nullable { + canExecuteModify = true + } + + if expectedCol.Default != oriCol.Default { + switch { + case expectedCol.IsAutoIncrement: // For autoincrement column, don't check default + case expectedCol.SQLType.Name == schemas.Bool || expectedCol.SQLType.Name == schemas.Boolean: + expectDefault, _ := strconv.ParseBool(expectedCol.Default) + oriDefault, _ := strconv.ParseBool(oriCol.Default) + if expectDefault != oriDefault { + engine.logger.Warnf("Table %s column %s db default is %s, struct default is %s", + tbNameWithSchema, expectedCol.Name, oriCol.Default, expectedCol.Default) + } + default: + engine.logger.Warnf("Table %s column %s db default is %s, struct default is %s", + tbNameWithSchema, expectedCol.Name, oriCol.Default, expectedCol.Default) + } + } + } + + if canExecuteModify { + if _, err := session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, expectedCol)); err != nil { + return err + } + } + + return nil +} + func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) (*SyncResult, error) { engine := session.engine @@ -136,13 +222,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) // check columns for _, col := range table.Columns() { - var oriCol *schemas.Column - for _, col2 := range oriTable.Columns() { - if strings.EqualFold(col.Name, col2.Name) { - oriCol = col2 - break - } - } + oriCol := oriTable.GetColumn(col.Name) // column is not exist on table if oriCol == nil { @@ -154,73 +234,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) continue } - err = nil - expectedType := engine.dialect.SQLType(col) - curType := engine.dialect.SQLType(oriCol) - if expectedType != curType { - if expectedType == schemas.Text && - strings.HasPrefix(curType, schemas.Varchar) { - // currently only support mysql & postgres - if engine.dialect.URI().DBType == schemas.MYSQL || - engine.dialect.URI().DBType == schemas.POSTGRES { - engine.logger.Infof("Table %s column %s change type from %s to %s\n", - tbNameWithSchema, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } else { - engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", - tbNameWithSchema, col.Name, curType, expectedType) - } - } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { - if engine.dialect.URI().DBType == schemas.POSTGRES || - engine.dialect.URI().DBType == schemas.MYSQL { - if oriCol.Length < col.Length { - engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } - } - } else { - if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { - if !strings.EqualFold(schemas.SQLTypeName(curType), engine.dialect.Alias(schemas.SQLTypeName(expectedType))) { - engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbNameWithSchema, col.Name, curType, expectedType) - } - } - } - } else if expectedType == schemas.Varchar { - if engine.dialect.URI().DBType == schemas.POSTGRES || - engine.dialect.URI().DBType == schemas.MYSQL { - if oriCol.Length < col.Length { - engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbNameWithSchema, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } - } - } else if col.Comment != oriCol.Comment { - if engine.dialect.URI().DBType == schemas.POSTGRES || - engine.dialect.URI().DBType == schemas.GBASE8S || - engine.dialect.URI().DBType == schemas.MYSQL { - _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) - } - } - - if col.Default != oriCol.Default { - switch { - case col.IsAutoIncrement: // For autoincrement column, don't check default - case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && - ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || - (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): - default: - engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", - tbName, col.Name, oriCol.Default, col.Default) - } - } - if col.Nullable != oriCol.Nullable { - engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v", - tbName, col.Name, oriCol.Nullable, col.Nullable) - } - - if err != nil { + if err := session.syncColumn(tbNameWithSchema, col, oriCol); err != nil { return nil, err } } diff --git a/tests/sync_test.go b/tests/sync_test.go index dedd3343..0d39e7d4 100644 --- a/tests/sync_test.go +++ b/tests/sync_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "xorm.io/xorm/schemas" ) type TestSync1 struct { @@ -32,3 +33,170 @@ func TestSync(t *testing.T) { assert.NoError(t, testEngine.Sync(new(TestSync1))) assert.NoError(t, testEngine.Sync(new(TestSync2))) } + +// Test Sync with varchar size changed +type TestSync3 struct { + Id int64 + Name string `xorm:"varchar(100)"` +} + +func (TestSync3) TableName() string { + return "test_sync_2" +} + +type TestSync4 struct { + Id int64 + Name string `xorm:"varchar(200)"` +} + +func (TestSync4) TableName() string { + return "test_sync_2" +} + +func Test_SyncVarcharSizeChange(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.SQLITE { + t.Skip("SQLite does not support column change") + } + + assert.NoError(t, testEngine.Sync(new(TestSync3))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + var testTable *schemas.Table + for _, table := range tables { + if table.Name == "test_sync_2" { + testTable = table + break + } + } + assert.NotNil(t, testTable) + assert.Len(t, testTable.Columns(), 2) + assert.Equal(t, "varchar", testTable.GetColumn("name").SQLType.Name) + assert.Equal(t, 100, testTable.GetColumn("name").Length) + + assert.NoError(t, testEngine.Sync(new(TestSync4))) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + testTable = nil + for _, table := range tables { + if table.Name == "test_sync_2" { + testTable = table + break + } + } + assert.NotNil(t, testTable) + assert.Len(t, testTable.Columns(), 2) + assert.Equal(t, "varchar", testTable.GetColumn("name").SQLType.Name) + assert.Equal(t, 200, testTable.GetColumn("name").Length) +} + +// Test Sync with varchar size changed +type TestSync5 struct { + Id int64 + Name string `xorm:"NOT NULL"` +} + +func (TestSync5) TableName() string { + return "test_sync_3" +} + +type TestSync6 struct { + Id int64 + Name string `xorm:"NULL"` +} + +func (TestSync6) TableName() string { + return "test_sync_3" +} + +func Test_SyncVarcharNullableChanged(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.SQLITE { + t.Skip("SQLite does not support column change") + } + + assert.NoError(t, testEngine.Sync(new(TestSync5))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + var testTable *schemas.Table + for _, table := range tables { + if table.Name == "test_sync_3" { + testTable = table + break + } + } + assert.NotNil(t, testTable) + assert.Len(t, testTable.Columns(), 2) + assert.False(t, testTable.GetColumn("name").Nullable) + + assert.NoError(t, testEngine.Sync(new(TestSync6))) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + testTable = nil + for _, table := range tables { + if table.Name == "test_sync_3" { + testTable = table + break + } + } + assert.NotNil(t, testTable) + assert.Len(t, testTable.Columns(), 2) + assert.True(t, testTable.GetColumn("name").Nullable) +} + +// Test Sync with varchar size changed +type TestSync7 struct { + Id int64 + Name string `xorm:"DEFAULT '1'"` +} + +func (TestSync7) TableName() string { + return "test_sync_4" +} + +type TestSync8 struct { + Id int64 + Name string `xorm:"DEFAULT '2'"` +} + +func (TestSync8) TableName() string { + return "test_sync_4" +} + +func Test_SyncVarcharDefaultChange(t *testing.T) { + if testEngine.Dialect().URI().DBType == schemas.SQLITE { + t.Skip("SQLite does not support column change") + } + + assert.NoError(t, testEngine.Sync(new(TestSync7))) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + var testTable *schemas.Table + for _, table := range tables { + if table.Name == "test_sync_4" { + testTable = table + break + } + } + assert.NotNil(t, testTable) + assert.Len(t, testTable.Columns(), 2) + assert.Equal(t, "1", testTable.GetColumn("name").Default) + + assert.NoError(t, testEngine.Sync(new(TestSync8))) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + testTable = nil + for _, table := range tables { + if table.Name == "test_sync_4" { + testTable = table + break + } + } + assert.NotNil(t, testTable) + assert.Len(t, testTable.Columns(), 2) + assert.Equal(t, "2", testTable.GetColumn("name").Default) +}