diff --git a/engine.go b/engine.go index 19c748c2..cfa0933e 100644 --- a/engine.go +++ b/engine.go @@ -259,7 +259,6 @@ func (engine *Engine) Close() error { func (engine *Engine) Ping() error { session := engine.NewSession() defer session.Close() - engine.logger.Infof("PING DATABASE %v", engine.DriverName()) return session.Ping() } @@ -1215,6 +1214,9 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { // table, column, index, unique. but will not delete or change anything. // If you change some field, you should change the database manually. func (engine *Engine) Sync(beans ...interface{}) error { + session := engine.NewSession() + defer session.Close() + for _, bean := range beans { v := rValue(bean) tableName := engine.tbName(v) @@ -1223,14 +1225,12 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } - s := engine.NewSession() - defer s.Close() - isExist, err := s.Table(bean).isTableExist(tableName) + isExist, err := session.Table(bean).isTableExist(tableName) if err != nil { return err } if !isExist { - err = engine.CreateTables(bean) + err = session.createTable(bean) if err != nil { return err } @@ -1241,11 +1241,11 @@ func (engine *Engine) Sync(beans ...interface{}) error { }*/ var isEmpty bool if isEmpty { - err = engine.DropTables(bean) + err = session.dropTable(bean) if err != nil { return err } - err = engine.CreateTables(bean) + err = session.createTable(bean) if err != nil { return err } @@ -1256,8 +1256,6 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - session := engine.NewSession() - defer session.Close() if err := session.statement.setRefValue(v); err != nil { return err } @@ -1269,8 +1267,6 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - session := engine.NewSession() - defer session.Close() if err := session.statement.setRefValue(v); err != nil { return err } @@ -1280,8 +1276,6 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - session := engine.NewSession() - defer session.Close() if err := session.statement.setRefValue(v); err != nil { return err } @@ -1297,8 +1291,6 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - session := engine.NewSession() - defer session.Close() if err := session.statement.setRefValue(v); err != nil { return err } @@ -1335,7 +1327,7 @@ func (engine *Engine) CreateTables(beans ...interface{}) error { } for _, bean := range beans { - err = session.CreateTable(bean) + err = session.createTable(bean) if err != nil { session.Rollback() return err @@ -1355,7 +1347,7 @@ func (engine *Engine) DropTables(beans ...interface{}) error { } for _, bean := range beans { - err = session.DropTable(bean) + err = session.dropTable(bean) if err != nil { session.Rollback() return err @@ -1465,10 +1457,10 @@ func (engine *Engine) Rows(bean interface{}) (*Rows, error) { } // Count counts the records. bean's non-empty fields are conditions. -func (engine *Engine) Count(bean interface{}) (int64, error) { +func (engine *Engine) Count(bean ...interface{}) (int64, error) { session := engine.NewSession() defer session.Close() - return session.Count(bean) + return session.Count(bean...) } // Sum sum the records by some column. bean's non-empty fields are conditions. diff --git a/session_schema.go b/session_schema.go index a6a72a40..c1d5088d 100644 --- a/session_schema.go +++ b/session_schema.go @@ -21,36 +21,47 @@ func (session *Session) Ping() error { defer session.Close() } + session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) return session.DB().Ping() } // CreateTable create a table according a bean func (session *Session) CreateTable(bean interface{}) error { + if session.isAutoClose { + defer session.Close() + } + + return session.createTable(bean) +} + +func (session *Session) createTable(bean interface{}) error { + defer session.resetStatement() v := rValue(bean) if err := session.statement.setRefValue(v); err != nil { return err } - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - - return session.createOneTable() + sqlStr := session.statement.genCreateTableSQL() + _, err := session.exec(sqlStr) + return err } // CreateIndexes create indexes func (session *Session) CreateIndexes(bean interface{}) error { + if session.isAutoClose { + defer session.Close() + } + + return session.createIndexes(bean) +} + +func (session *Session) createIndexes(bean interface{}) error { + defer session.resetStatement() v := rValue(bean) if err := session.statement.setRefValue(v); err != nil { return err } - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - sqls := session.statement.genIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) @@ -63,16 +74,19 @@ func (session *Session) CreateIndexes(bean interface{}) error { // CreateUniques create uniques func (session *Session) CreateUniques(bean interface{}) error { + if session.isAutoClose { + defer session.Close() + } + return session.createUniques(bean) +} + +func (session *Session) createUniques(bean interface{}) error { + defer session.resetStatement() v := rValue(bean) if err := session.statement.setRefValue(v); err != nil { return err } - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - sqls := session.statement.genUniqueSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) @@ -83,24 +97,22 @@ func (session *Session) CreateUniques(bean interface{}) error { return nil } -func (session *Session) createOneTable() error { - sqlStr := session.statement.genCreateTableSQL() - _, err := session.exec(sqlStr) - return err -} - // DropIndexes drop indexes func (session *Session) DropIndexes(bean interface{}) error { + if session.isAutoClose { + defer session.Close() + } + + return session.dropIndexes(bean) +} + +func (session *Session) dropIndexes(bean interface{}) error { + defer session.resetStatement() v := rValue(bean) if err := session.statement.setRefValue(v); err != nil { return err } - defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } - sqls := session.statement.genDelIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) @@ -113,6 +125,15 @@ func (session *Session) DropIndexes(bean interface{}) error { // DropTable drop table will drop table if exist, if drop failed, it will return error func (session *Session) DropTable(beanOrTableName interface{}) error { + if session.isAutoClose { + defer session.Close() + } + + return session.dropTable(beanOrTableName) +} + +func (session *Session) dropTable(beanOrTableName interface{}) error { + defer session.resetStatement() tableName, err := session.engine.tableName(beanOrTableName) if err != nil { return err @@ -138,6 +159,10 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { // IsTableExist if a table is exist func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) { + if session.isAutoClose { + defer session.Close() + } + tableName, err := session.engine.tableName(beanOrTableName) if err != nil { return false, err @@ -148,9 +173,6 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) func (session *Session) isTableExist(tableName string) (bool, error) { defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } sqlStr, args := session.engine.dialect.TableCheckSql(tableName) results, err := session.query(sqlStr, args...) return len(results) > 0, err @@ -162,6 +184,9 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { t := v.Type() if t.Kind() == reflect.String { + if session.isAutoClose { + defer session.Close() + } return session.isTableEmpty(bean.(string)) } else if t.Kind() == reflect.Struct { rows, err := session.Count(bean) @@ -172,9 +197,6 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) { defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } var total int64 sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) @@ -193,9 +215,6 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } indexes, err := session.engine.dialect.GetIndexes(tableName) if err != nil { @@ -215,9 +234,6 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo func (session *Session) addColumn(colName string) error { defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } col := session.statement.RefTable.GetColumn(colName) sql, args := session.statement.genAddColumnStr(col) @@ -227,9 +243,7 @@ func (session *Session) addColumn(colName string) error { func (session *Session) addIndex(tableName, idxName string) error { defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } + index := session.statement.RefTable.Indexes[idxName] sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) @@ -239,9 +253,7 @@ func (session *Session) addIndex(tableName, idxName string) error { func (session *Session) addUnique(tableName, uqeName string) error { defer session.resetStatement() - if session.isAutoClose { - defer session.Close() - } + index := session.statement.RefTable.Indexes[uqeName] sqlStr := session.engine.dialect.CreateIndexSql(tableName, index) _, err := session.exec(sqlStr) @@ -252,6 +264,11 @@ func (session *Session) addUnique(tableName, uqeName string) error { func (session *Session) Sync2(beans ...interface{}) error { engine := session.engine + if session.isAutoClose { + session.isAutoClose = false + defer session.Close() + } + tables, err := engine.DBMetas() if err != nil { return err @@ -277,17 +294,17 @@ func (session *Session) Sync2(beans ...interface{}) error { } if oriTable == nil { - err = session.StoreEngine(session.statement.StoreEngine).CreateTable(bean) + err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) if err != nil { return err } - err = session.CreateUniques(bean) + err = session.createUniques(bean) if err != nil { return err } - err = session.CreateIndexes(bean) + err = session.createIndexes(bean) if err != nil { return err } @@ -312,7 +329,7 @@ func (session *Session) Sync2(beans ...interface{}) error { engine.dialect.DBType() == core.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", tbName, col.Name, curType, expectedType) - _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", tbName, col.Name, curType, expectedType) @@ -322,7 +339,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbName, col.Name, oriCol.Length, col.Length) - _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) } } } else { @@ -336,7 +353,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbName, col.Name, oriCol.Length, col.Length) - _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) } } } @@ -349,10 +366,8 @@ func (session *Session) Sync2(beans ...interface{}) error { tbName, col.Name, oriCol.Nullable, col.Nullable) } } else { - session := engine.NewSession() session.statement.RefTable = table session.statement.tableName = tbName - defer session.Close() err = session.addColumn(col.Name) } if err != nil { @@ -376,7 +391,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { sql := engine.dialect.DropIndexSql(tbName, oriIndex) - _, err = engine.Exec(sql) + _, err = session.exec(sql) if err != nil { return err } @@ -392,7 +407,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { sql := engine.dialect.DropIndexSql(tbName, index2) - _, err = engine.Exec(sql) + _, err = session.exec(sql) if err != nil { return err } @@ -401,16 +416,12 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == core.UniqueType { - session := engine.NewSession() session.statement.RefTable = table session.statement.tableName = tbName - defer session.Close() err = session.addUnique(tbName, name) } else if index.Type == core.IndexType { - session := engine.NewSession() session.statement.RefTable = table session.statement.tableName = tbName - defer session.Close() err = session.addIndex(tbName, name) } if err != nil {