add lib/pq support; bool on struct will not be as condition;

This commit is contained in:
Lunny Xiao 2013-11-06 15:36:38 +08:00
parent 0c541edf3f
commit 31de5d612e
7 changed files with 278 additions and 158 deletions

View File

@ -8,7 +8,7 @@ Xorm is a simple and powerful ORM for Go. It makes dabatabse operating simple.
## Discuss ## Discuss
Message me on [G+](https://plus.google.com/u/0/106406879480103142585) Please visit [xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm)
## Drivers Support ## Drivers Support
@ -20,6 +20,8 @@ Drivers for Go's sql package which currently support database/sql includes:
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/bylevel/pq](https://github.com/lib/pq)
* Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq) * Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq)
@ -40,19 +42,23 @@ Drivers for Go's sql package which currently support database/sql includes:
## Features ## Features
* Struct<->Table Mapping Supports, both name mapping and filed tags mapping * Struct<->Table Mapping Supports, both name mapping and filed tag mapping
* Database Transaction Support * Database Transaction Support
* Both ORM and SQL Operation Support * Both ORM and SQL Operation Support
* Simply usage * Simply chainable usage
* Support Id, In, Where, Limit, Join, Having, Sql functions and sturct as query conditions * Support Id, In, Where, Limit, Join, Having, Sql functions and sturct as query conditions
* Support simple cascade load just like Hibernate for Java * Cache Support
* Code generator support, See [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md) * Reverse Tool Support
* Simple cascade load support
* Database Reverse Tool support, See [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md)
## Installing xorm ## Installing xorm

View File

@ -97,6 +97,9 @@ func insert(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
if user.Uid <= 0 {
t.Error(errors.New("not return id error"))
}
} }
func testQuery(engine *Engine, t *testing.T) { func testQuery(engine *Engine, t *testing.T) {
@ -149,6 +152,9 @@ func insertAutoIncr(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
if user.Uid <= 0 {
t.Error(errors.New("not return id error"))
}
} }
func insertMulti(engine *Engine, t *testing.T) { func insertMulti(engine *Engine, t *testing.T) {
@ -159,11 +165,14 @@ func insertMulti(engine *Engine, t *testing.T) {
{Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
} }
_, err := engine.Insert(&users) id, err := engine.Insert(&users)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
if id <= 0 {
t.Error(errors.New("not return id error"))
}
users2 := []*Userinfo{ users2 := []*Userinfo{
&Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
@ -172,11 +181,15 @@ func insertMulti(engine *Engine, t *testing.T) {
&Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, &Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
} }
_, err = engine.Insert(&users2) id, err = engine.Insert(&users2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
if id <= 0 {
t.Error(errors.New("not return id error"))
}
} }
func insertTwoTable(engine *Engine, t *testing.T) { func insertTwoTable(engine *Engine, t *testing.T) {
@ -245,6 +258,18 @@ func testdelete(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
_, err = engine.Id(2).Get(&user)
if err != nil {
t.Error(err)
panic(err)
}
_, err = engine.Delete(&user)
if err != nil {
t.Error(err)
panic(err)
}
} }
func get(engine *Engine, t *testing.T) { func get(engine *Engine, t *testing.T) {
@ -1156,6 +1181,66 @@ func testStrangeName(engine *Engine, t *testing.T) {
} }
} }
type Version struct {
Id int64
Name string
Ver int `xorm:"version"`
}
func testVersion(engine *Engine, t *testing.T) {
err := engine.DropTables(new(Version))
if err != nil {
t.Error(err)
return
}
err = engine.CreateTables(new(Version))
if err != nil {
t.Error(err)
return
}
ver := &Version{Name: "sfsfdsfds"}
_, err = engine.Cols("name").Insert(ver)
if err != nil {
t.Error(err)
return
}
newVer := new(Version)
has, err := engine.Id(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
return
}
if !has {
t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id)))
return
}
newVer.Name = "-------"
_, err = engine.Id(ver.Id).Update(newVer, &Version{Ver: newVer.Ver})
if err != nil {
t.Error(err)
return
}
has, err = engine.Id(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
return
}
fmt.Println(ver)
newVer.Name = "-------"
_, err = engine.Id(ver.Id).Update(newVer, &Version{Ver: newVer.Ver})
if err != nil {
t.Error(err)
return
}
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -1240,6 +1325,8 @@ func testAll2(engine *Engine, t *testing.T) {
testIterate(engine, t) testIterate(engine, t)
fmt.Println("-------------- testStrangeName --------------") fmt.Println("-------------- testStrangeName --------------")
testStrangeName(engine, t) testStrangeName(engine, t)
fmt.Println("-------------- testVersion --------------")
testVersion(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
} }

View File

@ -138,38 +138,48 @@ func (engine *Engine) Close() error {
return engine.Pool.Close(engine) return engine.Pool.Close(engine)
} }
// Test if database is alive. // Test method is deprecated, use Ping() method.
func (engine *Engine) Test() error { func (engine *Engine) Test() error {
return engine.Ping()
}
// Ping tests if database is alive.
func (engine *Engine) Ping() error {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
engine.LogSQL("PING DATABASE", engine.DriverName) engine.LogSQL("PING DATABASE", engine.DriverName)
return session.Ping() return session.Ping()
} }
// logging sql
func (engine *Engine) LogSQL(contents ...interface{}) { func (engine *Engine) LogSQL(contents ...interface{}) {
if engine.ShowSQL { if engine.ShowSQL {
io.WriteString(engine.Logger, fmt.Sprintln(contents...)) io.WriteString(engine.Logger, fmt.Sprintln(contents...))
} }
} }
// logging error
func (engine *Engine) LogError(contents ...interface{}) { func (engine *Engine) LogError(contents ...interface{}) {
if engine.ShowErr { if engine.ShowErr {
io.WriteString(engine.Logger, fmt.Sprintln(contents...)) io.WriteString(engine.Logger, fmt.Sprintln(contents...))
} }
} }
// logging debug
func (engine *Engine) LogDebug(contents ...interface{}) { func (engine *Engine) LogDebug(contents ...interface{}) {
if engine.ShowDebug { if engine.ShowDebug {
io.WriteString(engine.Logger, fmt.Sprintln(contents...)) io.WriteString(engine.Logger, fmt.Sprintln(contents...))
} }
} }
// logging warn
func (engine *Engine) LogWarn(contents ...interface{}) { func (engine *Engine) LogWarn(contents ...interface{}) {
if engine.ShowWarn { if engine.ShowWarn {
io.WriteString(engine.Logger, fmt.Sprintln(contents...)) io.WriteString(engine.Logger, fmt.Sprintln(contents...))
} }
} }
// execute sql
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
@ -182,6 +192,7 @@ func (engine *Engine) NoAutoTime() *Session {
return session.NoAutoTime() return session.NoAutoTime()
} }
// retrieve all tables, columns, indexes' informations from database.
func (engine *Engine) DBMetas() ([]*Table, error) { func (engine *Engine) DBMetas() ([]*Table, error) {
tables, err := engine.dialect.GetTables() tables, err := engine.dialect.GetTables()
if err != nil { if err != nil {
@ -215,30 +226,35 @@ func (engine *Engine) DBMetas() ([]*Table, error) {
return tables, nil return tables, nil
} }
// use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
return session.Cascade(trueOrFalse...) return session.Cascade(trueOrFalse...)
} }
// Where method provide a condition query
func (engine *Engine) Where(querystring string, args ...interface{}) *Session { func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
return session.Where(querystring, args...) return session.Where(querystring, args...)
} }
// Id mehtod provoide a condition as (id) = ?
func (engine *Engine) Id(id int64) *Session { func (engine *Engine) Id(id int64) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
return session.Id(id) return session.Id(id)
} }
// set charset when create table, only support mysql now
func (engine *Engine) Charset(charset string) *Session { func (engine *Engine) Charset(charset string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
return session.Charset(charset) return session.Charset(charset)
} }
// set store engine when create table, only support mysql now
func (engine *Engine) StoreEngine(storeEngine string) *Session { func (engine *Engine) StoreEngine(storeEngine string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
@ -257,12 +273,6 @@ func (engine *Engine) Omit(columns ...string) *Session {
return session.Omit(columns...) return session.Omit(columns...)
} }
/*func (engine *Engine) Trans(t string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Trans(t)
}*/
func (engine *Engine) In(column string, args ...interface{}) *Session { func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
@ -398,10 +408,11 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
col.Default = tags[j+1] col.Default = tags[j+1]
case k == "CREATED": case k == "CREATED":
col.IsCreated = true col.IsCreated = true
case k == "VERSION":
col.IsVersion = true
col.Default = "1"
case k == "UPDATED": case k == "UPDATED":
col.IsUpdated = true col.IsUpdated = true
/*case strings.HasPrefix(k, "--"):
col.Comment = k[2:len(k)]*/
case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
indexType = IndexType indexType = IndexType
indexName = k[len("INDEX")+1 : len(k)-1] indexName = k[len("INDEX")+1 : len(k)-1]
@ -487,7 +498,7 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
sqlType := Type2SQLType(fieldType) sqlType := Type2SQLType(fieldType)
col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false,
TWOSIDES, false, false, false} TWOSIDES, false, false, false, false}
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
col.Nullable = false col.Nullable = false

View File

@ -2,7 +2,8 @@ package xorm
import ( import (
"fmt" "fmt"
_ "github.com/bylevel/pq" //_ "github.com/bylevel/pq"
_ "github.com/lib/pq"
"testing" "testing"
) )

View File

@ -850,7 +850,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
session.Statement.RefTable = table session.Statement.RefTable = table
if len(condiBean) > 0 { if len(condiBean) > 0 {
colNames, args := buildConditions(session.Engine, table, condiBean[0]) colNames, args := buildConditions(session.Engine, table, condiBean[0], true)
session.Statement.ConditionStr = strings.Join(colNames, " and ") session.Statement.ConditionStr = strings.Join(colNames, " and ")
session.Statement.BeanArgs = args session.Statement.BeanArgs = args
} }
@ -1185,77 +1185,6 @@ func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice
session.Engine.LogSQL(paramStr) session.Engine.LogSQL(paramStr)
return query(session.Db, sql, paramStr...) return query(session.Db, sql, paramStr...)
/*s, err := session.Db.Prepare(sql)
if err != nil {
return nil, err
}
defer s.Close()
res, err := s.Query(paramStr...)
if err != nil {
return nil, err
}
defer res.Close()
fields, err := res.Columns()
if err != nil {
return nil, err
}
for res.Next() {
result := make(map[string][]byte)
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers = append(scanResultContainers, &scanResultContainer)
}
if err := res.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
aa := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
var str string
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
result[key] = []byte(str)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str)
case reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
result[key] = rawValue.Interface().([]byte)
default:
session.Engine.LogError("Unsupported type")
}
case reflect.String:
str = vv.String()
result[key] = []byte(str)
//时间类型
case reflect.Struct:
if aa.String() == "time.Time" {
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str)
} else {
session.Engine.LogError("Unsupported struct type")
}
default:
session.Engine.LogError("Unsupported type")
}
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil*/
} }
func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
@ -1315,7 +1244,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) { func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice { if sliceValue.Kind() != reflect.Slice {
return -1, errors.New("needs a pointer to a slice") return 0, errors.New("needs a pointer to a slice")
} }
bean := sliceValue.Index(0).Interface() bean := sliceValue.Index(0).Interface()
@ -1393,7 +1322,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", ")) colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
} }
statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);", statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
@ -1402,22 +1331,50 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
strings.Join(colMultiPlaces, "),(")) strings.Join(colMultiPlaces, "),("))
res, err := session.exec(statement, args...) if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" {
if err != nil { res, err := session.exec(statement, args...)
return -1, err if err != nil {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
if table.PrimaryKey != "" {
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
} else {
return 0, err
}
} else {
statement += " RETURNING (id)"
res, err := session.query(statement, args...)
if err != nil {
return 0, err
}
if len(res) < 1 {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
idByte := res[0][table.PrimaryKey]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil {
return 0, err
}
return id, nil
} }
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
id, err := res.LastInsertId()
if err != nil {
return -1, err
}
return id, nil
} }
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
@ -1644,7 +1601,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
colPlaces := strings.Repeat("?, ", len(colNames)) colPlaces := strings.Repeat("?, ", len(colNames))
colPlaces = colPlaces[0 : len(colPlaces)-2] colPlaces = colPlaces[0 : len(colPlaces)-2]
sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);", sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
session.Statement.TableName(), session.Statement.TableName(),
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
@ -1653,40 +1610,80 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.Engine.QuoteStr(), session.Engine.QuoteStr(),
colPlaces) colPlaces)
res, err := session.exec(sql, args...) // for postgres, many of them didn't implement lastInsertId, so we should
if err != nil { // implemented it ourself.
return 0, err if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" {
} res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache { if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.Statement.TableName())
} }
if table.PrimaryKey == "" { if table.PrimaryKey == "" {
return 0, nil return 0, nil
} }
var id int64 = 0 var id int64 = 0
pkValue := table.PKColumn().ValueOf(bean) pkValue := table.PKColumn().ValueOf(bean)
if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() {
return 0, nil return 0, nil
} }
id, err = res.LastInsertId() id, err = res.LastInsertId()
if err != nil || id <= 0 { if err != nil || id <= 0 {
return 0, err return 0, err
} }
var v interface{} = id var v interface{} = id
switch pkValue.Type().Kind() { switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
v = int(id) v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
v = uint(id) v = uint(id)
} }
pkValue.Set(reflect.ValueOf(v)) pkValue.Set(reflect.ValueOf(v))
return id, nil return id, nil
} else {
sql = sql + " RETURNING (id)"
res, err := session.query(sql, args...)
if err != nil {
return 0, err
}
if len(res) < 1 {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
pkValue := table.PKColumn().ValueOf(bean)
if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() {
return 0, nil
}
idByte := res[0][table.PrimaryKey]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil {
return 0, err
}
var v interface{} = id
switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
v = uint(id)
}
pkValue.Set(reflect.ValueOf(v))
return id, nil
}
} }
// Method InsertOne insert only one struct into database as a record. // Method InsertOne insert only one struct into database as a record.
@ -1864,17 +1861,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
session.Statement.RefTable = table session.Statement.RefTable = table
if session.Statement.ColumnStr == "" { if session.Statement.ColumnStr == "" {
colNames, args = buildConditions(session.Engine, table, bean) colNames, args = buildConditions(session.Engine, table, bean, false)
} else { } else {
colNames, args, err = table.genCols(session, bean, true, true) colNames, args, err = table.genCols(session, bean, true, true)
if err != nil { if err != nil {
return 0, err return 0, err
} }
} }
if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now())
}
} else if t.Kind() == reflect.Map { } else if t.Kind() == reflect.Map {
if session.Statement.RefTable == nil { if session.Statement.RefTable == nil {
return 0, ErrTableNotFound return 0, ErrTableNotFound
@ -1888,19 +1881,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?") colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface()) args = append(args, bValue.MapIndex(v).Interface())
} }
if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now())
}
} else { } else {
return 0, ErrParamsType return 0, ErrParamsType
} }
if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now())
}
var condiColNames []string var condiColNames []string
var condiArgs []interface{} var condiArgs []interface{}
if len(condiBean) > 0 { if len(condiBean) > 0 {
condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0]) condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true)
} }
var condition = "" var condition = ""
@ -1920,10 +1914,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
sql := fmt.Sprintf("UPDATE %v SET %v %v", var sql string
session.Engine.Quote(session.Statement.TableName()), if table.Version != "" {
strings.Join(colNames, ", "), sql = fmt.Sprintf("UPDATE %v SET %v, %v %v",
condition) session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1",
condition)
} else {
sql = fmt.Sprintf("UPDATE %v SET %v %v",
session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
condition)
}
args = append(append(args, st.Params...), condiArgs...) args = append(append(args, st.Params...), condiArgs...)
@ -2002,7 +2005,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
table := session.Engine.AutoMap(bean) table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := buildConditions(session.Engine, table, bean) colNames, args := buildConditions(session.Engine, table, bean, true)
var condition = "" var condition = ""
if session.Statement.WhereStr != "" { if session.Statement.WhereStr != "" {

View File

@ -78,15 +78,22 @@ func (statement *Statement) Table(tableNameOrBean interface{}) {
} }
} }
func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) { // Auto generating conditions according a struct
func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool) ([]string, []interface{}) {
colNames := make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
if !includeVersion && col.IsVersion {
continue
}
fieldValue := col.ValueOf(bean) fieldValue := col.ValueOf(bean)
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())
var val interface{} var val interface{}
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Bool: case reflect.Bool:
continue
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
val = fieldValue.Interface() val = fieldValue.Interface()
case reflect.String: case reflect.String:
if fieldValue.String() == "" { if fieldValue.String() == "" {
@ -364,7 +371,7 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.AutoMap(bean) table := statement.Engine.AutoMap(bean)
statement.RefTable = table statement.RefTable = table
colNames, args := buildConditions(statement.Engine, table, bean) colNames, args := buildConditions(statement.Engine, table, bean, true)
statement.ConditionStr = strings.Join(colNames, " and ") statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args statement.BeanArgs = args
@ -401,7 +408,7 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
table := statement.Engine.AutoMap(bean) table := statement.Engine.AutoMap(bean)
statement.RefTable = table statement.RefTable = table
colNames, args := buildConditions(statement.Engine, table, bean) colNames, args := buildConditions(statement.Engine, table, bean, true)
statement.ConditionStr = strings.Join(colNames, " and ") statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args statement.BeanArgs = args
var id string = "*" var id string = "*"

View File

@ -211,6 +211,7 @@ type Column struct {
IsCreated bool IsCreated bool
IsUpdated bool IsUpdated bool
IsCascade bool IsCascade bool
IsVersion bool
} }
func (col *Column) String(d dialect) string { func (col *Column) String(d dialect) string {
@ -264,6 +265,7 @@ type Table struct {
PrimaryKey string PrimaryKey string
Created string Created string
Updated string Updated string
Version string
Cacher Cacher Cacher Cacher
} }
@ -283,6 +285,9 @@ func (table *Table) AddColumn(col *Column) {
if col.IsUpdated { if col.IsUpdated {
table.Updated = col.Name table.Updated = col.Name
} }
if col.IsVersion {
table.Version = col.Name
}
} }
func (table *Table) AddIndex(index *Index) { func (table *Table) AddIndex(index *Index) {