diff --git a/engine.go b/engine.go index 23038e1d..205ed873 100644 --- a/engine.go +++ b/engine.go @@ -34,6 +34,8 @@ type dialect interface { SqlType(t *Column) string SupportInsertMany() bool QuoteStr() string + RollBackStr() string + DropTableSql(tableName string) string AutoIncrStr() string SupportEngine() bool SupportCharset() bool @@ -449,6 +451,18 @@ func (engine *Engine) newTable() *Table { return table } +func addIndex(indexName string, table *Table, col *Column, indexType int) { + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = true + } else { + index := NewIndex(indexName, indexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = true + } +} + func (engine *Engine) mapType(t reflect.Type) *Table { table := engine.newTable() table.Name = engine.tableMapper.Obj2Table(t.Name()) @@ -484,8 +498,9 @@ func (engine *Engine) mapType(t reflect.Type) *Table { table.PrimaryKeys = parentTable.PrimaryKeys continue } - var indexType int - var indexName string + + indexNames := make(map[string]int) + var isIndex, isUnique bool var preKey string for j, key := range tags { k := strings.ToUpper(key) @@ -521,15 +536,15 @@ func (engine *Engine) mapType(t reflect.Type) *Table { case k == "UPDATED": col.IsUpdated = true case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): - indexType = IndexType - indexName = k[len("INDEX")+1 : len(k)-1] + indexName := k[len("INDEX")+1 : len(k)-1] + indexNames[indexName] = IndexType case k == "INDEX": - indexType = IndexType + isIndex = true case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): - indexName = k[len("UNIQUE")+1 : len(k)-1] - indexType = UniqueType + indexName := k[len("UNIQUE")+1 : len(k)-1] + indexNames[indexName] = UniqueType case k == "UNIQUE": - indexType = UniqueType + isUnique = true case k == "NOTNULL": col.Nullable = false case k == "NOT": @@ -584,32 +599,15 @@ func (engine *Engine) mapType(t reflect.Type) *Table { if col.Name == "" { col.Name = engine.columnMapper.Obj2Table(t.Field(i).Name) } - if indexType == IndexType { - if indexName == "" { - indexName = col.Name - } - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = true - } else { - index := NewIndex(indexName, IndexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = true - } - } else if indexType == UniqueType { - if indexName == "" { - indexName = col.Name - } - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = true - } else { - index := NewIndex(indexName, UniqueType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = true - } + + if isUnique { + indexNames[col.Name] = UniqueType + } else if isIndex { + indexNames[col.Name] = IndexType + } + + for indexName, indexType := range indexNames { + addIndex(indexName, table, col, indexType) } } } else { @@ -810,6 +808,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } } else if index.Type == IndexType { + fmt.Println("index:", table.Name, name, index) isExist, err := session.isIndexExist2(table.Name, index.Cols, false) if err != nil { return err diff --git a/mssql.go b/mssql.go index 6e9776d2..54c93e71 100644 --- a/mssql.go +++ b/mssql.go @@ -108,6 +108,12 @@ func (db *mssql) AutoIncrStr() string { return "IDENTITY" } +func (db *mssql) DropTableSql(tableName string) string { + return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ + "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ + "DROP TABLE \"%s\"", tableName, tableName) +} + func (db *mssql) SupportCharset() bool { return false } @@ -187,7 +193,7 @@ where a.object_id=object_id('` + tableName + `')` if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" - }else{ + } else { if col.DefaultIsEmpty { col.Default = "''" } diff --git a/mysql.go b/mysql.go index 23b53641..8d0cfaa3 100644 --- a/mysql.go +++ b/mysql.go @@ -89,6 +89,14 @@ func (b *base) DBType() string { return b.uri.dbType } +func (db *base) RollBackStr() string { + return "ROLL BACK" +} + +func (db *base) DropTableSql(tableName string) string { + return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName) +} + type mysql struct { base net string diff --git a/ql.go b/ql.go deleted file mode 100644 index 8f26f3b5..00000000 --- a/ql.go +++ /dev/null @@ -1,232 +0,0 @@ -package xorm - -import ( - "database/sql" - "strings" -) - -type ql struct { - base -} - -type qlParser struct { -} - -func (p *qlParser) parse(driverName, dataSourceName string) (*uri, error) { - return &uri{dbType: QL, dbName: dataSourceName}, nil -} - -func (db *ql) Init(drivername, dataSourceName string) error { - return db.base.init(&qlParser{}, drivername, dataSourceName) -} - -func (db *ql) SqlType(c *Column) string { - switch t := c.SQLType.Name; t { - case Date, DateTime, TimeStamp, Time: - return Numeric - case TimeStampz: - return Text - case Char, Varchar, TinyText, Text, MediumText, LongText: - return Text - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: - return Integer - case Float, Double, Real: - return Real - case Decimal, Numeric: - return Numeric - case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: - return Blob - case Serial, BigSerial: - c.IsPrimaryKey = true - c.IsAutoIncrement = true - c.Nullable = false - return Integer - default: - return t - } -} - -func (db *ql) SupportInsertMany() bool { - return true -} - -func (db *ql) QuoteStr() string { - return "" -} - -func (db *ql) AutoIncrStr() string { - return "AUTOINCREMENT" -} - -func (db *ql) SupportEngine() bool { - return false -} - -func (db *ql) SupportCharset() bool { - return false -} - -func (db *ql) IndexOnTable() bool { - return false -} - -func (db *ql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{idxName} - return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args -} - -func (db *ql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args -} - -func (db *ql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName} - sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - return sql, args -} - -func (db *ql) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{tableName} - s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - cnn, err := sql.Open(db.driverName, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } - - var sql string - for _, record := range res { - for name, content := range record { - if name == "sql" { - sql = string(content) - } - } - } - - nStart := strings.Index(sql, "(") - nEnd := strings.Index(sql, ")") - colCreates := strings.Split(sql[nStart+1:nEnd], ",") - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, colStr := range colCreates { - fields := strings.Fields(strings.TrimSpace(colStr)) - col := new(Column) - col.Indexes = make(map[string]bool) - col.Nullable = true - for idx, field := range fields { - if idx == 0 { - col.Name = strings.Trim(field, "`[] ") - continue - } else if idx == 1 { - col.SQLType = SQLType{field, 0, 0} - } - switch field { - case "PRIMARY": - col.IsPrimaryKey = true - case "AUTOINCREMENT": - col.IsAutoIncrement = true - case "NULL": - if fields[idx-1] == "NOT" { - col.Nullable = false - } else { - col.Nullable = true - } - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } - return colSeq, cols, nil -} - -func (db *ql) GetTables() ([]*Table, error) { - args := []interface{}{} - s := "SELECT name FROM sqlite_master WHERE type='table'" - - cnn, err := sql.Open(db.driverName, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } - - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "name": - table.Name = string(content) - } - } - if table.Name == "sqlite_sequence" { - continue - } - tables = append(tables, table) - } - return tables, nil -} - -func (db *ql) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{tableName} - s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - cnn, err := sql.Open(db.driverName, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } - - indexes := make(map[string]*Index, 0) - for _, record := range res { - index := new(Index) - sql := string(record["sql"]) - - if sql == "" { - continue - } - - nNStart := strings.Index(sql, "INDEX") - nNEnd := strings.Index(sql, "ON") - if nNStart == -1 || nNEnd == -1 { - continue - } - - indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") - //fmt.Println(indexName) - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - index.Name = indexName[5+len(tableName) : len(indexName)] - } else { - index.Name = indexName - } - - if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { - index.Type = UniqueType - } else { - index.Type = IndexType - } - - nStart := strings.Index(sql, "(") - nEnd := strings.Index(sql, ")") - colIndexes := strings.Split(sql[nStart+1:nEnd], ",") - - index.Cols = make([]string, 0) - for _, col := range colIndexes { - index.Cols = append(index.Cols, strings.Trim(col, "` []")) - } - indexes[index.Name] = index - } - - return indexes, nil -} diff --git a/ql_test.go b/ql_test.go deleted file mode 100644 index 46d0104d..00000000 --- a/ql_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package xorm - -import ( - "database/sql" - "os" - "testing" - - _ "github.com/mattn/ql-driver" -) - -func newQlEngine() (*Engine, error) { - os.Remove("./ql.db") - return NewEngine("ql", "./ql.db") -} - -func newQlDriverDB() (*sql.DB, error) { - os.Remove("./ql.db") - return sql.Open("ql", "./ql.db") -} - -func TestQl(t *testing.T) { - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - engine.ShowSQL = showTestSql - engine.ShowErr = showTestSql - engine.ShowWarn = showTestSql - engine.ShowDebug = showTestSql - - testAll(engine, t) - testAll2(engine, t) - testAll3(engine, t) -} - -func TestQlWithCache(t *testing.T) { - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - engine.ShowSQL = showTestSql - engine.ShowErr = showTestSql - engine.ShowWarn = showTestSql - engine.ShowDebug = showTestSql - - testAll(engine, t) - testAll2(engine, t) -} - -const ( - createTableQl = "CREATE TABLE IF NOT EXISTS `big_struct` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, `title` TEXT NULL, `age` TEXT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL);" - dropTableQl = "DROP TABLE IF EXISTS `big_struct`;" -) - -func BenchmarkQlDriverInsert(t *testing.B) { - doBenchDriver(newQlDriverDB, createTableQl, dropTableQl, - doBenchDriverInsert, t) -} - -func BenchmarkQlDriverFind(t *testing.B) { - doBenchDriver(newQlDriverDB, createTableQl, dropTableQl, - doBenchDriverFind, t) -} - -func BenchmarkQlNoCacheInsert(t *testing.B) { - t.StopTimer() - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - //engine.ShowSQL = true - doBenchInsert(engine, t) -} - -func BenchmarkQlNoCacheFind(t *testing.B) { - t.StopTimer() - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - //engine.ShowSQL = true - doBenchFind(engine, t) -} - -func BenchmarkQlNoCacheFindPtr(t *testing.B) { - t.StopTimer() - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - //engine.ShowSQL = true - doBenchFindPtr(engine, t) -} - -func BenchmarkQlCacheInsert(t *testing.B) { - t.StopTimer() - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - doBenchInsert(engine, t) -} - -func BenchmarkQlCacheFind(t *testing.B) { - t.StopTimer() - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - doBenchFind(engine, t) -} - -func BenchmarkQlCacheFindPtr(t *testing.B) { - t.StopTimer() - engine, err := newQlEngine() - defer engine.Close() - if err != nil { - t.Error(err) - return - } - engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000)) - doBenchFindPtr(engine, t) -} diff --git a/session.go b/session.go index 731027ce..10bba1bb 100644 --- a/session.go +++ b/session.go @@ -286,7 +286,7 @@ func (session *Session) Begin() error { // When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.LogSQL("ROLL BACK") + session.Engine.LogSQL(session.Engine.dialect.RollBackStr()) session.IsCommitedOrRollbacked = true return session.Tx.Rollback() } diff --git a/statement.go b/statement.go index 80c5dccb..e4ea1a2c 100644 --- a/statement.go +++ b/statement.go @@ -25,7 +25,7 @@ type Statement struct { HavingStr string ColumnStr string columnMap map[string]bool - useAllCols bool + useAllCols bool OmitStr string ConditionStr string AltTableName string @@ -240,7 +240,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Auto generating conditions according a struct func buildConditions(engine *Engine, table *Table, bean interface{}, - includeVersion bool, includeUpdated bool, includeNil bool, + includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, boolColumnMap map[string]bool) ([]string, []interface{}) { @@ -712,13 +712,7 @@ func (s *Statement) genDelIndexSQL() []string { } func (s *Statement) genDropSQL() string { - if s.Engine.dialect.DBType() == MSSQL { - return "IF EXISTS (SELECT * FROM sysobjects WHERE id = object_id(N'" + - s.TableName() + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1) " + - "DROP TABLE " + s.Engine.Quote(s.TableName()) + ";" - } else { - return "DROP TABLE IF EXISTS " + s.Engine.Quote(s.TableName()) + ";" - } + return s.Engine.dialect.DropTableSql(s.TableName()) + ";" } func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { @@ -766,7 +760,7 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - true, statement.allUseBool, statement.useAllCols,statement.boolColumnMap) + true, statement.allUseBool, statement.useAllCols, statement.boolColumnMap) statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args