diff --git a/circle.yml b/circle.yml index 7e49116a..3063ac9d 100644 --- a/circle.yml +++ b/circle.yml @@ -21,7 +21,7 @@ database: test: override: # './...' is a relative pattern which means all subdirectories - - go test -v -race -db="sqlite3;mysql;postgres" -conn_str="./test.db;root:@/xorm_test;dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic + - go test -v -race -db="sqlite3::mysql::postgres" -conn_str="./test.db::root:@/xorm_test::dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh diff --git a/dialect_mssql.go b/dialect_mssql.go index f83cfc17..6d2291dc 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -215,7 +215,7 @@ func (db *mssql) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { case core.Bool: - res = core.TinyInt + res = core.Bit if strings.EqualFold(c.Default, "true") { c.Default = "1" } else { @@ -250,6 +250,9 @@ func (db *mssql) SqlType(c *core.Column) string { case core.Uuid: res = core.Varchar c.Length = 40 + case core.TinyInt: + res = core.TinyInt + c.Length = 0 default: res = t } @@ -335,9 +338,15 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable, - replace(replace(isnull(c.text,''),'(',''),')','') as vdefault - from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id - left join sys.syscomments c on a.default_object_id=c.id + replace(replace(isnull(c.text,''),'(',''),')','') as vdefault, + ISNULL(i.is_primary_key, 0) + from sys.columns a + left join sys.types b on a.user_type_id=b.user_type_id + left join sys.syscomments c on a.default_object_id=c.id + LEFT OUTER JOIN + sys.index_columns ic ON ic.object_id = a.object_id AND ic.column_id = a.column_id + LEFT OUTER JOIN + sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id where a.object_id=object_id('` + tableName + `')` db.LogSQL(s, args) @@ -352,8 +361,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column for rows.Next() { var name, ctype, vdefault string var maxLen, precision, scale int - var nullable bool - err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault) + var nullable, isPK bool + err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &vdefault, &isPK) if err != nil { return nil, nil, err } @@ -363,6 +372,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column col.Name = strings.Trim(name, "` ") col.Nullable = nullable col.Default = vdefault + col.IsPrimaryKey = isPK ct := strings.ToUpper(ctype) if ct == "DECIMAL" { col.Length = precision @@ -536,7 +546,6 @@ type odbcDriver struct { func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { kv := strings.Split(dataSourceName, ";") var dbName string - for _, c := range kv { vv := strings.Split(strings.TrimSpace(c), "=") if len(vv) == 2 { diff --git a/engine.go b/engine.go index 3195e47e..b2b2f97b 100644 --- a/engine.go +++ b/engine.go @@ -1512,9 +1512,14 @@ func (engine *Engine) NowTime2(sqlTypeName string) (interface{}, time.Time) { } func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) { - if col.DisableTimeZone { - return engine.formatTime(col.SQLType.Name, t) - } else if col.TimeZone != nil { + if t.IsZero() { + if col.Nullable { + return nil + } + return "" + } + + if col.TimeZone != nil { return engine.formatTime(col.SQLType.Name, t.In(col.TimeZone)) } return engine.formatTime(col.SQLType.Name, t.In(engine.DatabaseTZ)) diff --git a/session_cols_test.go b/session_cols_test.go index 9ee75904..33105281 100644 --- a/session_cols_test.go +++ b/session_cols_test.go @@ -7,6 +7,7 @@ package xorm import ( "testing" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -26,7 +27,11 @@ func TestSetExpr(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - cnt, err = testEngine.SetExpr("show", "NOT `show`").Id(1).Update(new(User)) + var not = "NOT" + if testEngine.dialect.DBType() == core.MSSQL { + not = "~" + } + cnt, err = testEngine.SetExpr("show", not+" `show`").Id(1).Update(new(User)) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } diff --git a/session_convert.go b/session_convert.go index 1616547e..df44ace7 100644 --- a/session_convert.go +++ b/session_convert.go @@ -586,11 +586,6 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val case reflect.Struct: if fieldType.ConvertibleTo(core.TimeType) { t := fieldValue.Convert(core.TimeType).Interface().(time.Time) - if session.Engine.dialect.DBType() == core.MSSQL { - if t.IsZero() { - return nil, nil - } - } tf := session.Engine.formatColTime(col, t) return tf, nil } diff --git a/session_get_test.go b/session_get_test.go index a8806ced..335d2570 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -119,6 +120,11 @@ func TestGetStruct(t *testing.T) { assert.NoError(t, testEngine.Sync(new(UserinfoGet))) + var err error + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT userinfo_get ON") + assert.NoError(t, err) + } cnt, err := testEngine.Insert(&UserinfoGet{Uid: 2}) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) diff --git a/session_insert.go b/session_insert.go index 93168c2e..c3648171 100644 --- a/session_insert.go +++ b/session_insert.go @@ -357,10 +357,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.Engine.QuoteStr(), colPlaces) } else { - if session.Engine.dialect.DBType() == core.SQLITE || session.Engine.dialect.DBType() == core.POSTGRES { - sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.Engine.Quote(session.Statement.TableName())) - } else { + if session.Engine.dialect.DBType() == core.MYSQL { sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.Engine.Quote(session.Statement.TableName())) + } else { + sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.Engine.Quote(session.Statement.TableName())) } } diff --git a/session_update.go b/session_update.go index 1d77d294..7cb38c22 100644 --- a/session_update.go +++ b/session_update.go @@ -298,7 +298,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 condSQL = "WHERE " + condSQL } } else if st.Engine.dialect.DBType() == core.MSSQL { - top = fmt.Sprintf("top (%d) ", st.LimitN) + if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL && + table != nil && len(table.PrimaryKeys) == 1 { + cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", + table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], + session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...) + + condSQL, condArgs, _ = builder.ToSQL(cond) + if len(condSQL) > 0 { + condSQL = "WHERE " + condSQL + } + } else { + top = fmt.Sprintf("TOP (%d) ", st.LimitN) + } } } diff --git a/session_update_test.go b/session_update_test.go index 46df78ca..2d41aceb 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -356,7 +356,6 @@ func TestUpdate1(t *testing.T) { And("departname = ?", ""). And("detail_id = ?", 0). And("is_man = ?", 0). - And("created IS NOT NULL"). Get(&Userinfo{}) if err != nil { t.Error(err) diff --git a/statement_test.go b/statement_test.go index 0cfd7f3e..594aa4f3 100644 --- a/statement_test.go +++ b/statement_test.go @@ -26,7 +26,7 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" { + if dbType == "postgres" || dbType == "mssql" { return } diff --git a/tag_extends_test.go b/tag_extends_test.go index c497b8d4..4b94fd4b 100644 --- a/tag_extends_test.go +++ b/tag_extends_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -328,6 +329,11 @@ func TestExtends2(t *testing.T) { Uid: sender.Id, ToUid: receiver.Id, } + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } + _, err = testEngine.Insert(&msg) if err != nil { t.Error(err) @@ -395,6 +401,10 @@ func TestExtends3(t *testing.T) { Uid: sender.Id, ToUid: receiver.Id, } + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } _, err = testEngine.Insert(&msg) if err != nil { t.Error(err) @@ -478,6 +488,10 @@ func TestExtends4(t *testing.T) { Content: "test", Uid: sender.Id, } + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT message ON") + assert.NoError(t, err) + } _, err = testEngine.Insert(&msg) if err != nil { t.Error(err) diff --git a/tag_id_test.go b/tag_id_test.go index 7a6d2c8a..be5f7337 100644 --- a/tag_id_test.go +++ b/tag_id_test.go @@ -38,7 +38,7 @@ func TestGonicMapperID(t *testing.T) { for _, tb := range tables { if tb.Name == "id_gonic_mapper" { - if len(tb.PKColumns()) != 1 && !tb.PKColumns()[0].IsPrimaryKey && !tb.PKColumns()[0].IsPrimaryKey { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { t.Fatal(tb) } return @@ -75,7 +75,7 @@ func TestSameMapperID(t *testing.T) { for _, tb := range tables { if tb.Name == "IDSameMapper" { - if len(tb.PKColumns()) != 1 && !tb.PKColumns()[0].IsPrimaryKey && !tb.PKColumns()[0].IsPrimaryKey { + if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "id" { t.Fatal(tb) } return diff --git a/test_mssql.sh b/test_mssql.sh new file mode 100755 index 00000000..6f9cf729 --- /dev/null +++ b/test_mssql.sh @@ -0,0 +1 @@ +go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" \ No newline at end of file diff --git a/xorm_test.go b/xorm_test.go index 66868b65..2e722f86 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -7,6 +7,7 @@ import ( "strings" "testing" + _ "github.com/denisenkom/go-mssqldb" _ "github.com/go-sql-driver/mysql" "github.com/go-xorm/core" _ "github.com/lib/pq" @@ -45,7 +46,10 @@ func createEngine(dbType, connStr string) error { for _, table := range tables { tableNames = append(tableNames, table.Name) } - return testEngine.DropTables(tableNames...) + if err = testEngine.DropTables(tableNames...); err != nil { + return err + } + return nil } func prepareEngine() error { @@ -70,8 +74,8 @@ func TestMain(m *testing.M) { connString = *ptrConnStr } - dbs := strings.Split(*db, ";") - conns := strings.Split(connString, ";") + dbs := strings.Split(*db, "::") + conns := strings.Split(connString, "::") var res int for i := 0; i < len(dbs); i++ {