diff --git a/base_test.go b/base_test.go index 6816fc70..9a592808 100644 --- a/base_test.go +++ b/base_test.go @@ -231,7 +231,7 @@ func insertMulti(engine *Engine, t *testing.T) { } func insertTwoTable(engine *Engine, t *testing.T) { - userdetail := Userdetail{Id: 1, Intro: "I'm a very beautiful women.", Profile: "sfsaf"} + userdetail := Userdetail{ /*Id: 1, */ Intro: "I'm a very beautiful women.", Profile: "sfsaf"} userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} cnt, err := engine.Insert(&userinfo, &userdetail) @@ -1173,7 +1173,8 @@ func testColTypes(engine *Engine, t *testing.T) { true, - 21, + 0, + //21, } cnt, err := engine.Insert(ac) @@ -1202,6 +1203,10 @@ func testColTypes(engine *Engine, t *testing.T) { newAc.Real = 0 newAc.Float = 0 newAc.Double = 0 + newAc.LongText = "" + newAc.TinyText = "" + newAc.MediumText = "" + newAc.Text = "" cnt, err = engine.Delete(newAc) if err != nil { t.Error(err) @@ -1286,6 +1291,9 @@ func testCustomType(engine *Engine, t *testing.T) { } fmt.Println(i) + i.NameArray = []string{} + i.MSS = map[string]string{} + i.F = 0 has, err := engine.Get(&i) if err != nil { t.Error(err) @@ -1312,6 +1320,8 @@ func testCustomType(engine *Engine, t *testing.T) { fmt.Println(sss) if has { + sss.NameArray = []string{} + sss.MSS = map[string]string{} cnt, err := engine.Delete(&sss) if err != nil { t.Error(err) diff --git a/engine.go b/engine.go index 56a286cb..648923fe 100644 --- a/engine.go +++ b/engine.go @@ -19,11 +19,14 @@ const ( SQLITE = "sqlite3" MYSQL = "mysql" MYMYSQL = "mymysql" + + MSSQL = "mssql" ) // a dialect is a driver's wrapper type dialect interface { Init(DriverName, DataSourceName string) error + DBType() string SqlType(t *Column) string SupportInsertMany() bool QuoteStr() string @@ -248,7 +251,7 @@ func (engine *Engine) DBMetas() ([]*Table, error) { if col, ok := table.Columns[name]; ok { col.Indexes[index.Name] = true } else { - return nil, errors.New("Unkonwn col " + name + " in indexes") + return nil, fmt.Errorf("Unknown col "+name+" in indexes %v", table.Columns) } } } @@ -746,6 +749,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { if err != nil { return err } + fmt.Println("-----", isExist) if !isExist { session := engine.NewSession() session.Statement.RefTable = table diff --git a/mssql.go b/mssql.go new file mode 100644 index 00000000..76188ecd --- /dev/null +++ b/mssql.go @@ -0,0 +1,273 @@ +package xorm + +import ( + //"crypto/tls" + "database/sql" + "errors" + "fmt" + //"regexp" + "strconv" + "strings" + //"time" +) + +type mssql struct { + base + quoteFilter Filter +} + +type mssqlParser struct { +} + +func (p *mssqlParser) parse(driverName, dataSourceName string) (*uri, error) { + return &uri{dbName: "xorm_test", dbType: MSSQL}, nil +} + +func (db *mssql) Init(drivername, uri string) error { + db.quoteFilter = &QuoteFilter{} + return db.base.init(&mssqlParser{}, drivername, uri) +} + +func (db *mssql) SqlType(c *Column) string { + var res string + switch t := c.SQLType.Name; t { + case Bool: + res = TinyInt + case Serial: + c.IsAutoIncrement = true + c.IsPrimaryKey = true + c.Nullable = false + res = Int + case BigSerial: + c.IsAutoIncrement = true + c.IsPrimaryKey = true + c.Nullable = false + res = BigInt + case Bytea, Blob, Binary, TinyBlob, MediumBlob, LongBlob: + res = VarBinary + if c.Length == 0 { + c.Length = 50 + } + case TimeStamp: + res = DateTime + case TimeStampz: + res = "DATETIMEOFFSET" + c.Length = 7 + case MediumInt: + res = Int + case MediumText, TinyText, LongText: + res = Text + case Double: + res = Real + default: + res = t + } + + if res == Int { + return Int + } + + var hasLen1 bool = (c.Length > 0) + var hasLen2 bool = (c.Length2 > 0) + if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } else if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } + return res +} + +func (db *mssql) SupportInsertMany() bool { + return true +} + +func (db *mssql) QuoteStr() string { + return "\"" +} + +func (db *mssql) SupportEngine() bool { + return false +} + +func (db *mssql) AutoIncrStr() string { + return "IDENTITY" +} + +func (db *mssql) SupportCharset() bool { + return false +} + +func (db *mssql) IndexOnTable() bool { + return true +} + +func (db *mssql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { + args := []interface{}{idxName} + sql := "select name from sysindexes where id=object_id('" + tableName + "') and name=?" + return sql, args +} + +func (db *mssql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { + args := []interface{}{tableName, colName} + sql := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` + return sql, args +} + +func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { + args := []interface{}{} + sql := "select * from sysobjects where id = object_id(N'" + tableName + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1" + return sql, args +} + +func (db *mssql) GetColumns(tableName string) ([]string, map[string]*Column, error) { + args := []interface{}{} + s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale +from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id +where a.object_id=object_id('` + tableName + `')` + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, record := range res { + col := new(Column) + col.Indexes = make(map[string]bool) + for name, content := range record { + switch name { + case "name": + col.Name = strings.Trim(string(content), "` ") + case "ctype": + ct := strings.ToUpper(string(content)) + switch ct { + case "DATETIMEOFFSET": + col.SQLType = SQLType{TimeStampz, 0, 0} + default: + if _, ok := sqlTypes[ct]; ok { + col.SQLType = SQLType{ct, 0, 0} + } else { + return nil, nil, errors.New(fmt.Sprintf("unknow colType %v", ct)) + } + } + + case "max_length": + len1, err := strconv.Atoi(strings.TrimSpace(string(content))) + if err != nil { + return nil, nil, err + } + col.Length = len1 + } + } + if col.SQLType.IsText() { + if col.Default != "" { + col.Default = "'" + col.Default + "'" + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + return colSeq, cols, nil +} + +func (db *mssql) GetTables() ([]*Table, error) { + args := []interface{}{} + s := `select name from sysobjects where xtype ='U'` + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "name": + table.Name = strings.Trim(string(content), "` ") + } + } + tables = append(tables, table) + } + return tables, nil +} + +func (db *mssql) GetIndexes(tableName string) (map[string]*Index, error) { + args := []interface{}{tableName} + s := `SELECT +IXS.NAME AS [INDEX_NAME], +C.NAME AS [COLUMN_NAME], +IXS.is_unique AS [IS_UNIQUE], +CASE IXCS.IS_INCLUDED_COLUMN +WHEN 0 THEN 'NONE' +ELSE 'INCLUDED' END AS [IS_INCLUDED_COLUMN] +FROM SYS.INDEXES IXS +INNER JOIN SYS.INDEX_COLUMNS IXCS +ON IXS.OBJECT_ID=IXCS.OBJECT_ID AND IXS.INDEX_ID = IXCS.INDEX_ID +INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID +AND IXCS.COLUMN_ID=C.COLUMN_ID +WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? +` + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + indexes := make(map[string]*Index, 0) + for _, record := range res { + fmt.Println("-----", record, "-----") + var indexType int + var indexName, colName string + for name, content := range record { + switch name { + case "IS_UNIQUE": + i, err := strconv.ParseBool(string(content)) + if err != nil { + return nil, err + } + + fmt.Println(name, string(content), i) + + if i { + indexType = UniqueType + } else { + indexType = IndexType + } + case "INDEX_NAME": + indexName = string(content) + case "COLUMN_NAME": + colName = strings.Trim(string(content), "` ") + } + } + + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + indexName = indexName[5+len(tableName) : len(indexName)] + } + + var index *Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(Index) + index.Type = indexType + index.Name = indexName + indexes[indexName] = index + } + index.AddColumn(colName) + fmt.Print("------end------") + } + return indexes, nil +} diff --git a/mssql_test.go b/mssql_test.go new file mode 100644 index 00000000..858aefc3 --- /dev/null +++ b/mssql_test.go @@ -0,0 +1,73 @@ +package xorm + +import ( + _ "code.google.com/p/odbc" + _ "github.com/mattn/go-adodb" + "testing" +) + +/* +CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET +utf8 COLLATE utf8_general_ci; +*/ + +func newMssqlEngine() (*Engine, error) { + //return NewEngine("adodb", "Provider=SQLOLEDB; Server=127.0.0.1;Database=xorm_test; uid=sa; pwd=1234;") + + return NewEngine("odbc", "driver={SQL Server};Server=127.0.0.1;Database=xorm_test; uid=sa; pwd=1234;") +} + +func TestMssql(t *testing.T) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) +} + +func TestMssqlWithCache(t *testing.T) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + engine.ShowSQL = showTestSql + engine.ShowErr = showTestSql + engine.ShowWarn = showTestSql + engine.ShowDebug = showTestSql + + testAll(engine, t) + testAll2(engine, t) +} + +func BenchmarkMssqlNoCache(t *testing.B) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + //engine.ShowSQL = true + doBenchFind(engine, t) +} + +func BenchmarkMssqlCache(t *testing.B) { + engine, err := newMssqlEngine() + defer engine.Close() + if err != nil { + t.Error(err) + return + } + engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) + doBenchFind(engine, t) +} diff --git a/mysql.go b/mysql.go index bde0186a..6e51780b 100644 --- a/mysql.go +++ b/mysql.go @@ -68,6 +68,10 @@ func (b *base) init(parser parser, drivername, dataSourceName string) (err error return } +func (b *base) DBType() string { + return b.uri.dbType +} + type mysql struct { base net string diff --git a/session.go b/session.go index c8de127d..4eebc7d2 100644 --- a/session.go +++ b/session.go @@ -1003,7 +1003,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1147,8 +1147,8 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo return false, err } - for _, index := range indexes { - //fmt.Println(i, "new:", cols, "-old:", index.Cols) + for i, index := range indexes { + fmt.Println(i, "new:", cols, "-old:", index.Cols, sliceEq(index.Cols, cols), unique, index.Type) if sliceEq(index.Cols, cols) { if unique { return index.Type == UniqueType, nil @@ -1352,6 +1352,7 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st return nil, err } defer rows.Close() + fmt.Println(rows) return rows2maps(rows) } @@ -1635,7 +1636,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var err error // for mysql, when use bit, it returned \x01 if col.SQLType.Name == Bit && - strings.Contains(session.Engine.DriverName, "mysql") { + session.Engine.dialect.DBType() == MYSQL { if len(data) == 1 { x = int64(data[0]) } else { @@ -1646,6 +1647,10 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data x, err = strconv.ParseInt(sdata, 16, 64) } else if strings.HasPrefix(sdata, "0") { x, err = strconv.ParseInt(sdata, 8, 64) + } else if strings.ToLower(sdata) == "true" { + x = 1 + } else if strings.ToLower(sdata) == "false" { + x = 0 } else { x, err = strconv.ParseInt(sdata, 10, 64) } @@ -1685,14 +1690,21 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data if err != nil { x, err = time.Parse("2006-01-02 15:04:05.999999999", sdata) } + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.9999999 Z07:00", sdata) + } } else if len(sdata) == 19 { x, err = time.Parse("2006-01-02 15:04:05", sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { x, err = time.Parse("2006-01-02", sdata) } else if col.SQLType.Name == Time { - if len(sdata) > 8 { - sdata = sdata[len(sdata)-8:] + if strings.Contains(sdata, " ") { + ssd := strings.Split(sdata, " ") + sdata = ssd[1] } + /*if len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + }*/ st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.Parse("2006-01-02 15:04:05", st) } else { @@ -2052,6 +2064,13 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( return fieldValue.String(), nil case reflect.Struct: if fieldType.String() == "time.Time" { + t := fieldValue.Interface().(time.Time) + + if session.Engine.dialect.DBType() == MSSQL { + if t.IsZero() { + return nil, nil + } + } if col.SQLType.Name == Time { //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") s := fieldValue.Interface().(time.Time).Format(time.RFC3339) @@ -2059,6 +2078,11 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( } else if col.SQLType.Name == Date { return fieldValue.Interface().(time.Time).Format("2006-01-02"), nil } else if col.SQLType.Name == TimeStampz { + if session.Engine.dialect.DBType() == MSSQL { + tf := t.Format("2006-01-02T15:04:05.9999999Z07:00") + fmt.Println("====", tf) + return tf, nil + } return fieldValue.Interface().(time.Time).Format(time.RFC3339Nano), nil } return fieldValue.Interface(), nil @@ -2475,7 +2499,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if session.Statement.ColumnStr == "" { colNames, args = buildConditions(session.Engine, table, bean, false, false, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, false, session.Statement.allUseBool, session.Statement.boolColumnMap) } else { colNames, args, err = table.genCols(session, bean, true, true) if err != nil { @@ -2509,7 +2533,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) } var condition = "" @@ -2697,7 +2721,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.autoMap(bean) session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, - false, session.Statement.allUseBool, session.Statement.boolColumnMap) + false, true, session.Statement.allUseBool, session.Statement.boolColumnMap) var condition = "" diff --git a/statement.go b/statement.go index 31718f18..18e3ad99 100644 --- a/statement.go +++ b/statement.go @@ -258,7 +258,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Auto generating conditions according a struct func buildConditions(engine *Engine, table *Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, allUseBool bool, + includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, boolColumnMap map[string]bool) ([]string, []interface{}) { colNames := make([]string, 0) @@ -270,6 +270,14 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, if !includeUpdated && col.IsUpdated { continue } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + // + fmt.Println(engine.dialect.DBType(), Text) + if engine.dialect.DBType() == MSSQL && col.SQLType.Name == Text { + continue + } fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) @@ -361,7 +369,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, if fieldValue == reflect.Zero(fieldType) { continue } - if fieldValue.IsNil() || !fieldValue.IsValid() { + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { continue } @@ -607,7 +615,13 @@ func (statement *Statement) genColumnStr() string { } func (statement *Statement) genCreateTableSQL() string { - sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " (" + var sql string + if statement.Engine.dialect.DBType() == MSSQL { + sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + statement.TableName() + "' ) CREATE TABLE" + } else { + sql = "CREATE TABLE IF NOT EXISTS " + } + sql += statement.Engine.Quote(statement.TableName()) + " (" pkList := []string{} @@ -702,8 +716,13 @@ func (s *Statement) genDelIndexSQL() []string { } func (s *Statement) genDropSQL() string { - sql := "DROP TABLE IF EXISTS " + s.Engine.Quote(s.TableName()) + ";" - return sql + if s.Engine.dialect.DBType() == MSSQL { + return "IF EXISTS (SELECT * FROM sysobjects WHERE id = object_id(N'" + + s.TableName() + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1) " + + "DROP TABLE " + s.Engine.Quote(s.TableName()) + ";" + } else { + return "DROP TABLE IF EXISTS " + s.Engine.Quote(s.TableName()) + ";" + } } // !nashtsai! REVIEW, Statement is a huge struct why is this method not passing *Statement? @@ -712,7 +731,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, - false, statement.allUseBool, statement.boolColumnMap) + false, true, statement.allUseBool, statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -751,7 +770,7 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - statement.allUseBool, statement.boolColumnMap) + true, statement.allUseBool, statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args @@ -797,11 +816,18 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if statement.Start > 0 { - a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) + if statement.Engine.dialect.DBType() != MSSQL { + if statement.Start > 0 { + a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) + } else if statement.LimitN > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) + } + } else { + /*SELECT * FROM ( + SELECT *, ROW_NUMBER() OVER (ORDER BY id desc) as row FROM "userinfo" + ) a WHERE row > [start] and row <= [start+limit] order by id desc*/ } + return } diff --git a/xorm.go b/xorm.go index 23a4a69b..10d6c20c 100644 --- a/xorm.go +++ b/xorm.go @@ -34,6 +34,9 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { engine.Filters = append(engine.Filters, &QuoteFilter{}) } else if driverName == MYMYSQL { engine.dialect = &mymysql{} + } else if driverName == "odbc" { + engine.dialect = &mssql{quoteFilter: &QuoteFilter{}} + engine.Filters = append(engine.Filters, &QuoteFilter{}) } else { return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) }