bug fix
This commit is contained in:
parent
5adfc8e923
commit
dd06d9a4cf
20
engine.go
20
engine.go
|
@ -227,6 +227,8 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
|
||||||
col.Default = tags[j+1]
|
col.Default = tags[j+1]
|
||||||
case k == "text":
|
case k == "text":
|
||||||
col.SQLType = Text
|
col.SQLType = Text
|
||||||
|
case k == "blob":
|
||||||
|
col.SQLType = Blob
|
||||||
case strings.HasPrefix(k, "int"):
|
case strings.HasPrefix(k, "int"):
|
||||||
if k == "int" {
|
if k == "int" {
|
||||||
col.SQLType = Int
|
col.SQLType = Int
|
||||||
|
@ -396,6 +398,24 @@ func (e *Engine) CreateTables(beans ...interface{}) error {
|
||||||
return session.Commit()
|
return session.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) DropTables(beans ...interface{}) error {
|
||||||
|
session := e.NewSession()
|
||||||
|
err := session.Begin()
|
||||||
|
defer session.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bean := range beans {
|
||||||
|
err = session.DropTable(bean)
|
||||||
|
if err != nil {
|
||||||
|
session.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return session.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) CreateAll() error {
|
func (e *Engine) CreateAll() error {
|
||||||
session := e.NewSession()
|
session := e.NewSession()
|
||||||
err := session.Begin()
|
err := session.Begin()
|
||||||
|
|
23
session.go
23
session.go
|
@ -335,22 +335,14 @@ func (session *Session) CreateTable(bean interface{}) error {
|
||||||
defer statement.Init()
|
defer statement.Init()
|
||||||
statement.RefTable = session.Engine.AutoMap(bean)
|
statement.RefTable = session.Engine.AutoMap(bean)
|
||||||
sql := statement.genCreateSQL()
|
sql := statement.genCreateSQL()
|
||||||
res, err := session.Exec(sql)
|
_, err := session.Exec(sql)
|
||||||
if err != nil {
|
if err == nil {
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if affected > 0 {
|
|
||||||
sql = statement.genIndexSQL()
|
sql = statement.genIndexSQL()
|
||||||
if len(sql) > 0 {
|
if len(sql) > 0 {
|
||||||
_, err = session.Exec(sql)
|
_, err = session.Exec(sql)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err == nil && affected > 0 {
|
if err == nil {
|
||||||
sql = statement.genUniqueSQL()
|
sql = statement.genUniqueSQL()
|
||||||
if len(sql) > 0 {
|
if len(sql) > 0 {
|
||||||
_, err = session.Exec(sql)
|
_, err = session.Exec(sql)
|
||||||
|
@ -359,6 +351,15 @@ func (session *Session) CreateTable(bean interface{}) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (session *Session) DropTable(bean interface{}) error {
|
||||||
|
statement := session.Statement
|
||||||
|
defer statement.Init()
|
||||||
|
statement.RefTable = session.Engine.AutoMap(bean)
|
||||||
|
sql := statement.genDropSQL()
|
||||||
|
_, err := session.Exec(sql)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (session *Session) Get(bean interface{}) (bool, error) {
|
func (session *Session) Get(bean interface{}) (bool, error) {
|
||||||
statement := session.Statement
|
statement := session.Statement
|
||||||
defer statement.Init()
|
defer statement.Init()
|
||||||
|
|
|
@ -244,7 +244,7 @@ func (statement *Statement) genCreateSQL() string {
|
||||||
func (statement *Statement) genIndexSQL() string {
|
func (statement *Statement) genIndexSQL() string {
|
||||||
var sql string = ""
|
var sql string = ""
|
||||||
for indexName, cols := range statement.RefTable.Indexes {
|
for indexName, cols := range statement.RefTable.Indexes {
|
||||||
sql += fmt.Sprintf("CREATE INDEX IF NOT EXISTS IDX_%v_%v ON %v (%v);", statement.TableName(), indexName,
|
sql += fmt.Sprintf("CREATE INDEX IDX_%v_%v ON %v (%v);", statement.TableName(), indexName,
|
||||||
statement.TableName(), strings.Join(cols, ","))
|
statement.TableName(), strings.Join(cols, ","))
|
||||||
}
|
}
|
||||||
return sql
|
return sql
|
||||||
|
@ -253,7 +253,7 @@ func (statement *Statement) genIndexSQL() string {
|
||||||
func (statement *Statement) genUniqueSQL() string {
|
func (statement *Statement) genUniqueSQL() string {
|
||||||
var sql string = ""
|
var sql string = ""
|
||||||
for indexName, cols := range statement.RefTable.Uniques {
|
for indexName, cols := range statement.RefTable.Uniques {
|
||||||
sql += fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS UQE_%v_%v ON %v (%v);", statement.TableName(), indexName,
|
sql += fmt.Sprintf("CREATE UNIQUE INDEX UQE_%v_%v ON %v (%v);", statement.TableName(), indexName,
|
||||||
statement.TableName(), strings.Join(cols, ","))
|
statement.TableName(), strings.Join(cols, ","))
|
||||||
}
|
}
|
||||||
return sql
|
return sql
|
||||||
|
|
|
@ -41,7 +41,13 @@ type Userdetail struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func directCreateTable(engine *Engine, t *testing.T) {
|
func directCreateTable(engine *Engine, t *testing.T) {
|
||||||
err := engine.CreateTables(&Userinfo{})
|
err := engine.DropTables(&Userinfo{})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = engine.CreateTables(&Userinfo{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue