From c5ee68faa17ca08a1ada8f0715ea385e2c5769dc Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 30 Sep 2019 15:09:57 +0800 Subject: [PATCH] Fix wrong dbmetas (#1442) * add tests for db metas * add more tests * fix bug on mssql --- dialect_mssql.go | 7 ++--- dialect_sqlite3.go | 67 +++++++++++++++++++++++++++------------------- tag_test.go | 49 +++++++++++++++++++++++++++++++++ xorm_test.go | 8 +++++- 4 files changed, 99 insertions(+), 32 deletions(-) diff --git a/dialect_mssql.go b/dialect_mssql.go index 524d05a4..29070da2 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -340,7 +340,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, "default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END), replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, - ISNULL(i.is_primary_key, 0) + ISNULL(i.is_primary_key, 0), a.is_identity as is_identity from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id left join sys.syscomments c on a.default_object_id=c.id @@ -362,8 +362,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column for rows.Next() { var name, ctype, vdefault string var maxLen, precision, scale int - var nullable, isPK, defaultIsNull bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK) + var nullable, isPK, defaultIsNull, isIncrement bool + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) if err != nil { return nil, nil, err } @@ -377,6 +377,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column col.Default = vdefault } col.IsPrimaryKey = isPK + col.IsAutoIncrement = isIncrement ct := strings.ToUpper(ctype) if ct == "DECIMAL" { col.Length = precision diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index d1852e9b..0a290f3c 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -298,6 +298,40 @@ func splitColStr(colStr string) []string { return results } +func parseString(colStr string) (*core.Column, error) { + fields := splitColStr(colStr) + col := new(core.Column) + col.Indexes = make(map[string]int) + col.Nullable = true + col.DefaultIsEmpty = true + + for idx, field := range fields { + if idx == 0 { + col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`) + continue + } else if idx == 1 { + col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} + continue + } + switch field { + case "PRIMARY": + col.IsPrimaryKey = true + case "AUTOINCREMENT": + col.IsAutoIncrement = true + case "NULL": + if fields[idx-1] == "NOT" { + col.Nullable = false + } else { + col.Nullable = true + } + case "DEFAULT": + col.Default = fields[idx+1] + col.DefaultIsEmpty = false + } + } + return col, nil +} + func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" @@ -327,6 +361,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu colCreates := reg.FindAllString(name[nStart+1:nEnd], -1) cols := make(map[string]*core.Column) colSeq := make([]string, 0) + for _, colStr := range colCreates { reg = regexp.MustCompile(`,\s`) colStr = reg.ReplaceAllString(colStr, ",") @@ -343,35 +378,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu continue } - fields := splitColStr(colStr) - col := new(core.Column) - col.Indexes = make(map[string]int) - col.Nullable = true - col.DefaultIsEmpty = true - - for idx, field := range fields { - if idx == 0 { - col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`) - continue - } else if idx == 1 { - col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0} - } - switch field { - case "PRIMARY": - col.IsPrimaryKey = true - case "AUTOINCREMENT": - col.IsAutoIncrement = true - case "NULL": - if fields[idx-1] == "NOT" { - col.Nullable = false - } else { - col.Nullable = true - } - case "DEFAULT": - col.Default = fields[idx+1] - col.DefaultIsEmpty = false - } + col, err := parseString(colStr) + if err != nil { + return colSeq, cols, err } + cols[col.Name] = col colSeq = append(colSeq, col.Name) } diff --git a/tag_test.go b/tag_test.go index 891c6ffc..979ba929 100644 --- a/tag_test.go +++ b/tag_test.go @@ -549,3 +549,52 @@ func TestSplitTag(t *testing.T) { } } } + +func TestTagAutoIncr(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TagAutoIncr struct { + Id int64 + Name string + } + + assertSync(t, new(TagAutoIncr)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name) + col := tables[0].GetColumn(colMapper.Obj2Table("Id")) + assert.NotNil(t, col) + assert.True(t, col.IsPrimaryKey) + assert.True(t, col.IsAutoIncrement) + + col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) + assert.NotNil(t, col2) + assert.False(t, col2.IsPrimaryKey) + assert.False(t, col2.IsAutoIncrement) +} + +func TestTagPrimarykey(t *testing.T) { + assert.NoError(t, prepareEngine()) + type TagPrimaryKey struct { + Id int64 `xorm:"pk"` + Name string `xorm:"VARCHAR(20) pk"` + } + + assertSync(t, new(TagPrimaryKey)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name) + col := tables[0].GetColumn(colMapper.Obj2Table("Id")) + assert.NotNil(t, col) + assert.True(t, col.IsPrimaryKey) + assert.False(t, col.IsAutoIncrement) + + col2 := tables[0].GetColumn(colMapper.Obj2Table("Name")) + assert.NotNil(t, col2) + assert.True(t, col2.IsPrimaryKey) + assert.False(t, col2.IsAutoIncrement) +} diff --git a/xorm_test.go b/xorm_test.go index c0302df3..21715256 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -15,10 +15,10 @@ import ( _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" - "xorm.io/core" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" _ "github.com/ziutek/mymysql/godrv" + "xorm.io/core" ) var ( @@ -35,6 +35,9 @@ var ( splitter = flag.String("splitter", ";", "the splitter on connstr for cluster") schema = flag.String("schema", "", "specify the schema") ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb") + + tableMapper core.IMapper + colMapper core.IMapper ) func createEngine(dbType, connStr string) error { @@ -122,6 +125,9 @@ func createEngine(dbType, connStr string) error { } } + tableMapper = testEngine.GetTableMapper() + colMapper = testEngine.GetColumnMapper() + tables, err := testEngine.DBMetas() if err != nil { return err