This commit is contained in:
Lunny Xiao 2013-07-27 21:47:22 +08:00
parent 5adfc8e923
commit dd06d9a4cf
4 changed files with 41 additions and 14 deletions

View File

@ -227,6 +227,8 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
col.Default = tags[j+1]
case k == "text":
col.SQLType = Text
case k == "blob":
col.SQLType = Blob
case strings.HasPrefix(k, "int"):
if k == "int" {
col.SQLType = Int
@ -396,6 +398,24 @@ func (e *Engine) CreateTables(beans ...interface{}) error {
return session.Commit()
}
func (e *Engine) DropTables(beans ...interface{}) error {
session := e.NewSession()
err := session.Begin()
defer session.Close()
if err != nil {
return err
}
for _, bean := range beans {
err = session.DropTable(bean)
if err != nil {
session.Rollback()
return err
}
}
return session.Commit()
}
func (e *Engine) CreateAll() error {
session := e.NewSession()
err := session.Begin()

View File

@ -335,22 +335,14 @@ func (session *Session) CreateTable(bean interface{}) error {
defer statement.Init()
statement.RefTable = session.Engine.AutoMap(bean)
sql := statement.genCreateSQL()
res, err := session.Exec(sql)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected > 0 {
_, err := session.Exec(sql)
if err == nil {
sql = statement.genIndexSQL()
if len(sql) > 0 {
_, err = session.Exec(sql)
}
}
if err == nil && affected > 0 {
if err == nil {
sql = statement.genUniqueSQL()
if len(sql) > 0 {
_, err = session.Exec(sql)
@ -359,6 +351,15 @@ func (session *Session) CreateTable(bean interface{}) error {
return err
}
func (session *Session) DropTable(bean interface{}) error {
statement := session.Statement
defer statement.Init()
statement.RefTable = session.Engine.AutoMap(bean)
sql := statement.genDropSQL()
_, err := session.Exec(sql)
return err
}
func (session *Session) Get(bean interface{}) (bool, error) {
statement := session.Statement
defer statement.Init()

View File

@ -244,7 +244,7 @@ func (statement *Statement) genCreateSQL() string {
func (statement *Statement) genIndexSQL() string {
var sql string = ""
for indexName, cols := range statement.RefTable.Indexes {
sql += fmt.Sprintf("CREATE INDEX IF NOT EXISTS IDX_%v_%v ON %v (%v);", statement.TableName(), indexName,
sql += fmt.Sprintf("CREATE INDEX IDX_%v_%v ON %v (%v);", statement.TableName(), indexName,
statement.TableName(), strings.Join(cols, ","))
}
return sql
@ -253,7 +253,7 @@ func (statement *Statement) genIndexSQL() string {
func (statement *Statement) genUniqueSQL() string {
var sql string = ""
for indexName, cols := range statement.RefTable.Uniques {
sql += fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS UQE_%v_%v ON %v (%v);", statement.TableName(), indexName,
sql += fmt.Sprintf("CREATE UNIQUE INDEX UQE_%v_%v ON %v (%v);", statement.TableName(), indexName,
statement.TableName(), strings.Join(cols, ","))
}
return sql

View File

@ -41,7 +41,13 @@ type Userdetail struct {
}
func directCreateTable(engine *Engine, t *testing.T) {
err := engine.CreateTables(&Userinfo{})
err := engine.DropTables(&Userinfo{})
if err != nil {
t.Error(err)
return
}
err = engine.CreateTables(&Userinfo{})
if err != nil {
t.Error(err)
}