add lib/pq support; bool on struct will not be as condition;

This commit is contained in:
Lunny Xiao 2013-11-06 15:36:38 +08:00
parent 0c541edf3f
commit 31de5d612e
7 changed files with 278 additions and 158 deletions

View File

@ -8,7 +8,7 @@ Xorm is a simple and powerful ORM for Go. It makes dabatabse operating simple.
## Discuss
Message me on [G+](https://plus.google.com/u/0/106406879480103142585)
Please visit [xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm)
## Drivers Support
@ -20,6 +20,8 @@ Drivers for Go's sql package which currently support database/sql includes:
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/bylevel/pq](https://github.com/lib/pq)
* Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq)
@ -40,19 +42,23 @@ Drivers for Go's sql package which currently support database/sql includes:
## Features
* Struct<->Table Mapping Supports, both name mapping and filed tags mapping
* Struct<->Table Mapping Supports, both name mapping and filed tag mapping
* Database Transaction Support
* Database Transaction Support
* Both ORM and SQL Operation Support
* Both ORM and SQL Operation Support
* Simply usage
* Simply chainable usage
* Support Id, In, Where, Limit, Join, Having, Sql functions and sturct as query conditions
* Support simple cascade load just like Hibernate for Java
* Support Id, In, Where, Limit, Join, Having, Sql functions and sturct as query conditions
* Code generator support, See [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md)
* Cache Support
* Reverse Tool Support
* Simple cascade load support
* Database Reverse Tool support, See [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md)
## Installing xorm

View File

@ -97,6 +97,9 @@ func insert(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
if user.Uid <= 0 {
t.Error(errors.New("not return id error"))
}
}
func testQuery(engine *Engine, t *testing.T) {
@ -149,6 +152,9 @@ func insertAutoIncr(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
if user.Uid <= 0 {
t.Error(errors.New("not return id error"))
}
}
func insertMulti(engine *Engine, t *testing.T) {
@ -159,11 +165,14 @@ func insertMulti(engine *Engine, t *testing.T) {
{Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
_, err := engine.Insert(&users)
id, err := engine.Insert(&users)
if err != nil {
t.Error(err)
panic(err)
}
if id <= 0 {
t.Error(errors.New("not return id error"))
}
users2 := []*Userinfo{
&Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
@ -172,11 +181,15 @@ func insertMulti(engine *Engine, t *testing.T) {
&Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
_, err = engine.Insert(&users2)
id, err = engine.Insert(&users2)
if err != nil {
t.Error(err)
panic(err)
}
if id <= 0 {
t.Error(errors.New("not return id error"))
}
}
func insertTwoTable(engine *Engine, t *testing.T) {
@ -245,6 +258,18 @@ func testdelete(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
_, err = engine.Id(2).Get(&user)
if err != nil {
t.Error(err)
panic(err)
}
_, err = engine.Delete(&user)
if err != nil {
t.Error(err)
panic(err)
}
}
func get(engine *Engine, t *testing.T) {
@ -1156,6 +1181,66 @@ func testStrangeName(engine *Engine, t *testing.T) {
}
}
type Version struct {
Id int64
Name string
Ver int `xorm:"version"`
}
func testVersion(engine *Engine, t *testing.T) {
err := engine.DropTables(new(Version))
if err != nil {
t.Error(err)
return
}
err = engine.CreateTables(new(Version))
if err != nil {
t.Error(err)
return
}
ver := &Version{Name: "sfsfdsfds"}
_, err = engine.Cols("name").Insert(ver)
if err != nil {
t.Error(err)
return
}
newVer := new(Version)
has, err := engine.Id(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
return
}
if !has {
t.Error(errors.New(fmt.Sprintf("no version id is %v", ver.Id)))
return
}
newVer.Name = "-------"
_, err = engine.Id(ver.Id).Update(newVer, &Version{Ver: newVer.Ver})
if err != nil {
t.Error(err)
return
}
has, err = engine.Id(ver.Id).Get(newVer)
if err != nil {
t.Error(err)
return
}
fmt.Println(ver)
newVer.Name = "-------"
_, err = engine.Id(ver.Id).Update(newVer, &Version{Ver: newVer.Ver})
if err != nil {
t.Error(err)
return
}
}
func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t)
@ -1240,6 +1325,8 @@ func testAll2(engine *Engine, t *testing.T) {
testIterate(engine, t)
fmt.Println("-------------- testStrangeName --------------")
testStrangeName(engine, t)
fmt.Println("-------------- testVersion --------------")
testVersion(engine, t)
fmt.Println("-------------- transaction --------------")
transaction(engine, t)
}

View File

@ -138,38 +138,48 @@ func (engine *Engine) Close() error {
return engine.Pool.Close(engine)
}
// Test if database is alive.
// Test method is deprecated, use Ping() method.
func (engine *Engine) Test() error {
return engine.Ping()
}
// Ping tests if database is alive.
func (engine *Engine) Ping() error {
session := engine.NewSession()
defer session.Close()
engine.LogSQL("PING DATABASE", engine.DriverName)
return session.Ping()
}
// logging sql
func (engine *Engine) LogSQL(contents ...interface{}) {
if engine.ShowSQL {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
// logging error
func (engine *Engine) LogError(contents ...interface{}) {
if engine.ShowErr {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
// logging debug
func (engine *Engine) LogDebug(contents ...interface{}) {
if engine.ShowDebug {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
// logging warn
func (engine *Engine) LogWarn(contents ...interface{}) {
if engine.ShowWarn {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
// execute sql
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
@ -182,6 +192,7 @@ func (engine *Engine) NoAutoTime() *Session {
return session.NoAutoTime()
}
// retrieve all tables, columns, indexes' informations from database.
func (engine *Engine) DBMetas() ([]*Table, error) {
tables, err := engine.dialect.GetTables()
if err != nil {
@ -215,30 +226,35 @@ func (engine *Engine) DBMetas() ([]*Table, error) {
return tables, nil
}
// use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Cascade(trueOrFalse...)
}
// Where method provide a condition query
func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Where(querystring, args...)
}
// Id mehtod provoide a condition as (id) = ?
func (engine *Engine) Id(id int64) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Id(id)
}
// set charset when create table, only support mysql now
func (engine *Engine) Charset(charset string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Charset(charset)
}
// set store engine when create table, only support mysql now
func (engine *Engine) StoreEngine(storeEngine string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
@ -257,12 +273,6 @@ func (engine *Engine) Omit(columns ...string) *Session {
return session.Omit(columns...)
}
/*func (engine *Engine) Trans(t string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Trans(t)
}*/
func (engine *Engine) In(column string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
@ -398,10 +408,11 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
col.Default = tags[j+1]
case k == "CREATED":
col.IsCreated = true
case k == "VERSION":
col.IsVersion = true
col.Default = "1"
case k == "UPDATED":
col.IsUpdated = true
/*case strings.HasPrefix(k, "--"):
col.Comment = k[2:len(k)]*/
case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
indexType = IndexType
indexName = k[len("INDEX")+1 : len(k)-1]
@ -487,7 +498,7 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
sqlType := Type2SQLType(fieldType)
col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false,
TWOSIDES, false, false, false}
TWOSIDES, false, false, false, false}
}
if col.IsAutoIncrement {
col.Nullable = false

View File

@ -2,7 +2,8 @@ package xorm
import (
"fmt"
_ "github.com/bylevel/pq"
//_ "github.com/bylevel/pq"
_ "github.com/lib/pq"
"testing"
)

View File

@ -850,7 +850,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
session.Statement.RefTable = table
if len(condiBean) > 0 {
colNames, args := buildConditions(session.Engine, table, condiBean[0])
colNames, args := buildConditions(session.Engine, table, condiBean[0], true)
session.Statement.ConditionStr = strings.Join(colNames, " and ")
session.Statement.BeanArgs = args
}
@ -1185,77 +1185,6 @@ func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice
session.Engine.LogSQL(paramStr)
return query(session.Db, sql, paramStr...)
/*s, err := session.Db.Prepare(sql)
if err != nil {
return nil, err
}
defer s.Close()
res, err := s.Query(paramStr...)
if err != nil {
return nil, err
}
defer res.Close()
fields, err := res.Columns()
if err != nil {
return nil, err
}
for res.Next() {
result := make(map[string][]byte)
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers = append(scanResultContainers, &scanResultContainer)
}
if err := res.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
aa := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
var str string
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
result[key] = []byte(str)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str)
case reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
result[key] = rawValue.Interface().([]byte)
default:
session.Engine.LogError("Unsupported type")
}
case reflect.String:
str = vv.String()
result[key] = []byte(str)
//时间类型
case reflect.Struct:
if aa.String() == "time.Time" {
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str)
} else {
session.Engine.LogError("Unsupported struct type")
}
default:
session.Engine.LogError("Unsupported type")
}
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil*/
}
func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
@ -1315,7 +1244,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return -1, errors.New("needs a pointer to a slice")
return 0, errors.New("needs a pointer to a slice")
}
bean := sliceValue.Index(0).Interface()
@ -1393,7 +1322,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
}
statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
statement := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
session.Engine.QuoteStr(),
session.Statement.TableName(),
session.Engine.QuoteStr(),
@ -1402,22 +1331,50 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
session.Engine.QuoteStr(),
strings.Join(colMultiPlaces, "),("))
res, err := session.exec(statement, args...)
if err != nil {
return -1, err
if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" {
res, err := session.exec(statement, args...)
if err != nil {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
if table.PrimaryKey != "" {
id, err := res.LastInsertId()
if err != nil {
return 0, err
}
return id, nil
} else {
return 0, err
}
} else {
statement += " RETURNING (id)"
res, err := session.query(statement, args...)
if err != nil {
return 0, err
}
if len(res) < 1 {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
idByte := res[0][table.PrimaryKey]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil {
return 0, err
}
return id, nil
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
id, err := res.LastInsertId()
if err != nil {
return -1, err
}
return id, nil
}
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
@ -1644,7 +1601,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
colPlaces := strings.Repeat("?, ", len(colNames))
colPlaces = colPlaces[0 : len(colPlaces)-2]
sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v);",
sql := fmt.Sprintf("INSERT INTO %v%v%v (%v%v%v) VALUES (%v)",
session.Engine.QuoteStr(),
session.Statement.TableName(),
session.Engine.QuoteStr(),
@ -1653,40 +1610,80 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.Engine.QuoteStr(),
colPlaces)
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
// for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself.
if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" {
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
if table.PrimaryKey == "" {
return 0, nil
}
if table.PrimaryKey == "" {
return 0, nil
}
var id int64 = 0
pkValue := table.PKColumn().ValueOf(bean)
if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() {
return 0, nil
}
var id int64 = 0
pkValue := table.PKColumn().ValueOf(bean)
if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() {
return 0, nil
}
id, err = res.LastInsertId()
if err != nil || id <= 0 {
return 0, err
}
id, err = res.LastInsertId()
if err != nil || id <= 0 {
return 0, err
}
var v interface{} = id
switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
v = uint(id)
}
pkValue.Set(reflect.ValueOf(v))
var v interface{} = id
switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
v = uint(id)
}
pkValue.Set(reflect.ValueOf(v))
return id, nil
return id, nil
} else {
sql = sql + " RETURNING (id)"
res, err := session.query(sql, args...)
if err != nil {
return 0, err
}
if len(res) < 1 {
return 0, err
}
if table.Cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
pkValue := table.PKColumn().ValueOf(bean)
if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() {
return 0, nil
}
idByte := res[0][table.PrimaryKey]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil {
return 0, err
}
var v interface{} = id
switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
v = uint(id)
}
pkValue.Set(reflect.ValueOf(v))
return id, nil
}
}
// Method InsertOne insert only one struct into database as a record.
@ -1864,17 +1861,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
session.Statement.RefTable = table
if session.Statement.ColumnStr == "" {
colNames, args = buildConditions(session.Engine, table, bean)
colNames, args = buildConditions(session.Engine, table, bean, false)
} else {
colNames, args, err = table.genCols(session, bean, true, true)
if err != nil {
return 0, err
}
}
if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now())
}
} else if t.Kind() == reflect.Map {
if session.Statement.RefTable == nil {
return 0, ErrTableNotFound
@ -1888,19 +1881,20 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
colNames = append(colNames, session.Engine.Quote(v.String())+" = ?")
args = append(args, bValue.MapIndex(v).Interface())
}
if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now())
}
} else {
return 0, ErrParamsType
}
if session.Statement.UseAutoTime && table.Updated != "" {
colNames = append(colNames, session.Engine.Quote(table.Updated)+" = ?")
args = append(args, time.Now())
}
var condiColNames []string
var condiArgs []interface{}
if len(condiBean) > 0 {
condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0])
condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true)
}
var condition = ""
@ -1920,10 +1914,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
sql := fmt.Sprintf("UPDATE %v SET %v %v",
session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
condition)
var sql string
if table.Version != "" {
sql = fmt.Sprintf("UPDATE %v SET %v, %v %v",
session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1",
condition)
} else {
sql = fmt.Sprintf("UPDATE %v SET %v %v",
session.Engine.Quote(session.Statement.TableName()),
strings.Join(colNames, ", "),
condition)
}
args = append(append(args, st.Params...), condiArgs...)
@ -2002,7 +2005,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
table := session.Engine.AutoMap(bean)
session.Statement.RefTable = table
colNames, args := buildConditions(session.Engine, table, bean)
colNames, args := buildConditions(session.Engine, table, bean, true)
var condition = ""
if session.Statement.WhereStr != "" {

View File

@ -78,15 +78,22 @@ func (statement *Statement) Table(tableNameOrBean interface{}) {
}
}
func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) {
// Auto generating conditions according a struct
func buildConditions(engine *Engine, table *Table, bean interface{}, includeVersion bool) ([]string, []interface{}) {
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if !includeVersion && col.IsVersion {
continue
}
fieldValue := col.ValueOf(bean)
fieldType := reflect.TypeOf(fieldValue.Interface())
var val interface{}
switch fieldType.Kind() {
case reflect.Bool:
continue
// if a bool in a struct, it will not be as a condition because it default is false,
// please use Where() instead
val = fieldValue.Interface()
case reflect.String:
if fieldValue.String() == "" {
@ -364,7 +371,7 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
table := statement.Engine.AutoMap(bean)
statement.RefTable = table
colNames, args := buildConditions(statement.Engine, table, bean)
colNames, args := buildConditions(statement.Engine, table, bean, true)
statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
@ -401,7 +408,7 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
table := statement.Engine.AutoMap(bean)
statement.RefTable = table
colNames, args := buildConditions(statement.Engine, table, bean)
colNames, args := buildConditions(statement.Engine, table, bean, true)
statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
var id string = "*"

View File

@ -211,6 +211,7 @@ type Column struct {
IsCreated bool
IsUpdated bool
IsCascade bool
IsVersion bool
}
func (col *Column) String(d dialect) string {
@ -264,6 +265,7 @@ type Table struct {
PrimaryKey string
Created string
Updated string
Version string
Cacher Cacher
}
@ -283,6 +285,9 @@ func (table *Table) AddColumn(col *Column) {
if col.IsUpdated {
table.Updated = col.Name
}
if col.IsVersion {
table.Version = col.Name
}
}
func (table *Table) AddIndex(index *Index) {