diff --git a/engine.go b/engine.go index dab313be..b6207706 100644 --- a/engine.go +++ b/engine.go @@ -227,6 +227,8 @@ func (engine *Engine) MapType(t reflect.Type) *Table { col.Default = tags[j+1] case k == "text": col.SQLType = Text + case k == "blob": + col.SQLType = Blob case strings.HasPrefix(k, "int"): if k == "int" { col.SQLType = Int @@ -396,6 +398,24 @@ func (e *Engine) CreateTables(beans ...interface{}) error { return session.Commit() } +func (e *Engine) DropTables(beans ...interface{}) error { + session := e.NewSession() + err := session.Begin() + defer session.Close() + if err != nil { + return err + } + + for _, bean := range beans { + err = session.DropTable(bean) + if err != nil { + session.Rollback() + return err + } + } + return session.Commit() +} + func (e *Engine) CreateAll() error { session := e.NewSession() err := session.Begin() diff --git a/session.go b/session.go index 42993d56..f68257ff 100644 --- a/session.go +++ b/session.go @@ -335,22 +335,14 @@ func (session *Session) CreateTable(bean interface{}) error { defer statement.Init() statement.RefTable = session.Engine.AutoMap(bean) sql := statement.genCreateSQL() - res, err := session.Exec(sql) - if err != nil { - return err - } - - affected, err := res.RowsAffected() - if err != nil { - return err - } - if affected > 0 { + _, err := session.Exec(sql) + if err == nil { sql = statement.genIndexSQL() if len(sql) > 0 { _, err = session.Exec(sql) } } - if err == nil && affected > 0 { + if err == nil { sql = statement.genUniqueSQL() if len(sql) > 0 { _, err = session.Exec(sql) @@ -359,6 +351,15 @@ func (session *Session) CreateTable(bean interface{}) error { return err } +func (session *Session) DropTable(bean interface{}) error { + statement := session.Statement + defer statement.Init() + statement.RefTable = session.Engine.AutoMap(bean) + sql := statement.genDropSQL() + _, err := session.Exec(sql) + return err +} + func (session *Session) Get(bean interface{}) (bool, error) { statement := session.Statement defer statement.Init() diff --git a/statement.go b/statement.go index b734a167..7fc28885 100644 --- a/statement.go +++ b/statement.go @@ -244,7 +244,7 @@ func (statement *Statement) genCreateSQL() string { func (statement *Statement) genIndexSQL() string { var sql string = "" for indexName, cols := range statement.RefTable.Indexes { - sql += fmt.Sprintf("CREATE INDEX IF NOT EXISTS IDX_%v_%v ON %v (%v);", statement.TableName(), indexName, + sql += fmt.Sprintf("CREATE INDEX IDX_%v_%v ON %v (%v);", statement.TableName(), indexName, statement.TableName(), strings.Join(cols, ",")) } return sql @@ -253,7 +253,7 @@ func (statement *Statement) genIndexSQL() string { func (statement *Statement) genUniqueSQL() string { var sql string = "" for indexName, cols := range statement.RefTable.Uniques { - sql += fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS UQE_%v_%v ON %v (%v);", statement.TableName(), indexName, + sql += fmt.Sprintf("CREATE UNIQUE INDEX UQE_%v_%v ON %v (%v);", statement.TableName(), indexName, statement.TableName(), strings.Join(cols, ",")) } return sql diff --git a/testbase.go b/testbase.go index 4e237f0f..6ae6edef 100644 --- a/testbase.go +++ b/testbase.go @@ -41,7 +41,13 @@ type Userdetail struct { } func directCreateTable(engine *Engine, t *testing.T) { - err := engine.CreateTables(&Userinfo{}) + err := engine.DropTables(&Userinfo{}) + if err != nil { + t.Error(err) + return + } + + err = engine.CreateTables(&Userinfo{}) if err != nil { t.Error(err) }