Merge remote-tracking branch 'refs/remotes/go-xorm/master'

Conflicts:
	session.go
This commit is contained in:
hzmnet 2015-11-03 01:04:49 +08:00
commit fe502f7ae6
16 changed files with 608 additions and 278 deletions

View File

@ -1,4 +1,4 @@
Copyright (c) 2013 - 2015 Copyright (c) 2013 - 2015 The Xorm Authors
All rights reserved. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without

View File

@ -2,6 +2,8 @@
Xorm is a simple and powerful ORM for Go. Xorm is a simple and powerful ORM for Go.
[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
[![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") [![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge")
# Features # Features
@ -33,18 +35,29 @@ Drivers for Go's sql package which currently support database/sql includes:
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/lib/pq](https://github.com/lib/pq) * Postgres: [github.com/lib/pq](https://github.com/lib/pq)
* Tidb: [github.com/pingcap/tidb](https://github.com/pingcap/tidb)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) * MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb)
* MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc)
* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) * Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment)
* ql: [github.com/cznic/ql](https://github.com/cznic/ql) (experiment)
# Changelog # Changelog
* **v0.4.4**
* ql database expriment support
* tidb database expriment support
* sql.NullString and etc. field support
* select ForUpdate support
* many bugs fixed
* **v0.4.3** * **v0.4.3**
* Json column type support * Json column type support
* oracle expirement support * oracle expirement support

View File

@ -4,6 +4,8 @@
xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。
[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
[![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm)
## 特性 ## 特性
@ -34,28 +36,34 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/lib/pq](https://github.com/lib/pq) * Postgres: [github.com/lib/pq](https://github.com/lib/pq)
* Tidb: [github.com/pingcap/tidb](https://github.com/pingcap/tidb)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) * MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb)
* MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc)
* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持) * Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持)
* ql: [github.com/cznic/ql](https://github.com/cznic/ql) (试验性支持)
## 更新日志 ## 更新日志
* **v0.4.4**
* Tidb 数据库支持
* QL 试验性支持
* sql.NullString支持
* ForUpdate 支持
* bug修正
* **v0.4.3** * **v0.4.3**
* Json 字段类型支持 * Json 字段类型支持
* oracle实验性支持 * oracle实验性支持
* bug修正 * bug修正
* **v0.4.2**
* 事物如未Rollback或Commit在关闭时会自动Rollback
* Gonic 映射支持
* bug修正
[更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16) [更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16)
## 安装 ## 安装

View File

@ -1 +1 @@
xorm v0.4.3.0520 xorm v0.4.4.1029

View File

@ -664,10 +664,12 @@ func (engine *Engine) autoMapType(v reflect.Value) *core.Table {
if !ok { if !ok {
table = engine.mapType(v) table = engine.mapType(v)
engine.Tables[t] = table engine.Tables[t] = table
if v.CanAddr() { if engine.Cacher != nil {
engine.GobRegister(v.Addr().Interface()) if v.CanAddr() {
} else { engine.GobRegister(v.Addr().Interface())
engine.GobRegister(v.Interface()) } else {
engine.GobRegister(v.Interface())
}
} }
} }
engine.mutex.Unlock() engine.mutex.Unlock()

Binary file not shown.

View File

@ -33,7 +33,7 @@ func test(engine *xorm.Engine) {
return return
} }
size := 500 size := 100
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
@ -83,7 +83,7 @@ func test(engine *xorm.Engine) {
} }
func main() { func main() {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(2)
fmt.Println("-----start sqlite go routines-----") fmt.Println("-----start sqlite go routines-----")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {

View File

@ -220,6 +220,11 @@ func (db *mssql) SqlType(c *core.Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bool: case core.Bool:
res = core.TinyInt res = core.TinyInt
if c.Default == "true" {
c.Default = "1"
} else if c.Default == "false" {
c.Default = "0"
}
case core.Serial: case core.Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
@ -504,6 +509,10 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars
return sql return sql
} }
func (db *mssql) ForUpdateSql(query string) string {
return query
}
func (db *mssql) Filters() []core.Filter { func (db *mssql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}}
} }

View File

@ -913,7 +913,8 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName} pgSchema := "public"
args := []interface{}{tableName,pgSchema}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -924,7 +925,7 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1 AND f.attnum > 0 ORDER BY f.attnum;` WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
if db.Logger != nil { if db.Logger != nil {

View File

@ -23,6 +23,10 @@ type BeforeSetProcessor interface {
BeforeSet(string, Cell) BeforeSet(string, Cell)
} }
type AfterSetProcessor interface {
AfterSet(string, Cell)
}
// !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations
//// Executed before an object is validated //// Executed before an object is validated
//type BeforeValidateProcessor interface { //type BeforeValidateProcessor interface {

View File

@ -45,7 +45,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable) sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable)
} }
rows.session.Engine.logSQL(sqlStr, args) rows.session.saveLastSQL(sqlStr, args)
var err error var err error
rows.stmt, err = rows.session.DB().Prepare(sqlStr) rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -41,6 +42,10 @@ type Session struct {
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr))
cascadeDeep int cascadeDeep int
// !evalphobia! stored the last executed query on this session
lastSQL string
lastSQLArgs []interface{}
} }
// Method Init reset the session as the init status. // Method Init reset the session as the init status.
@ -58,6 +63,9 @@ func (session *Session) Init() {
session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0)
session.beforeClosures = make([]func(interface{}), 0) session.beforeClosures = make([]func(interface{}), 0)
session.afterClosures = make([]func(interface{}), 0) session.afterClosures = make([]func(interface{}), 0)
session.lastSQL = ""
session.lastSQLArgs = []interface{}{}
} }
// Method Close release the connection from pool // Method Close release the connection from pool
@ -203,6 +211,12 @@ func (session *Session) Distinct(columns ...string) *Session {
return session return session
} }
// Set Read/Write locking for UPDATE
func (session *Session) ForUpdate() *Session {
session.Statement.IsForUpdate = true
return session
}
// Only not use the paramters as select or update columns // Only not use the paramters as select or update columns
func (session *Session) Omit(columns ...string) *Session { func (session *Session) Omit(columns ...string) *Session {
session.Statement.Omit(columns...) session.Statement.Omit(columns...)
@ -304,8 +318,7 @@ func (session *Session) Begin() error {
session.IsAutoCommit = false session.IsAutoCommit = false
session.IsCommitedOrRollbacked = false session.IsCommitedOrRollbacked = false
session.Tx = tx session.Tx = tx
session.saveLastSQL("BEGIN TRANSACTION")
session.Engine.logSQL("BEGIN TRANSACTION")
} }
return nil return nil
} }
@ -313,7 +326,7 @@ func (session *Session) Begin() error {
// When using transaction, you can rollback if any error // When using transaction, you can rollback if any error
func (session *Session) Rollback() error { func (session *Session) Rollback() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
session.Engine.logSQL(session.Engine.dialect.RollBackStr()) session.saveLastSQL(session.Engine.dialect.RollBackStr())
session.IsCommitedOrRollbacked = true session.IsCommitedOrRollbacked = true
return session.Tx.Rollback() return session.Tx.Rollback()
} }
@ -323,7 +336,7 @@ func (session *Session) Rollback() error {
// When using transaction, Commit will commit all operations. // When using transaction, Commit will commit all operations.
func (session *Session) Commit() error { func (session *Session) Commit() error {
if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked {
session.Engine.logSQL("COMMIT") session.saveLastSQL("COMMIT")
session.IsCommitedOrRollbacked = true session.IsCommitedOrRollbacked = true
var err error var err error
if err = session.Tx.Commit(); err == nil { if err = session.Tx.Commit(); err == nil {
@ -444,7 +457,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
session.Engine.logSQL(sqlStr, args...) session.saveLastSQL(sqlStr, args...)
return session.Engine.LogSQLExecutionTime(sqlStr, args, func() (sql.Result, error) { return session.Engine.LogSQLExecutionTime(sqlStr, args, func() (sql.Result, error) {
if session.IsAutoCommit { if session.IsAutoCommit {
@ -587,11 +600,15 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
return nil return nil
} }
func (statement *Statement) JoinColumns(cols []*core.Column) string { func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string {
var colnames = make([]string, len(cols)) var colnames = make([]string, len(cols))
for i, col := range cols { for i, col := range cols {
colnames[i] = statement.Engine.Quote(statement.TableName()) + if includeTableName {
"." + statement.Engine.Quote(col.Name) colnames[i] = statement.Engine.Quote(statement.TableName()) +
"." + statement.Engine.Quote(col.Name)
} else {
colnames[i] = statement.Engine.Quote(col.Name)
}
} }
return strings.Join(colnames, ", ") return strings.Join(colnames, ", ")
} }
@ -603,16 +620,33 @@ func (statement *Statement) convertIdSql(sqlStr string) string {
return "" return ""
} }
colstrs := statement.JoinColumns(cols) colstrs := statement.JoinColumns(cols, false)
sqls := splitNNoCase(sqlStr, "from", 2) sqls := splitNNoCase(sqlStr, " from ", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
return "" return ""
} }
if statement.Engine.dialect.DBType() == "ql" {
return fmt.Sprintf("SELECT id() FROM %v", sqls[1])
}
return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1]) return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
} }
return "" return ""
} }
<<<<<<< HEAD
=======
func (session *Session) canCache() bool {
if session.Statement.RefTable == nil ||
session.Statement.JoinStr != "" ||
session.Statement.RawSQL != "" ||
session.Tx != nil ||
len(session.Statement.selectStr) > 0 {
return false
}
return true
}
>>>>>>> refs/remotes/go-xorm/master
func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) {
// if has no reftable, then don't use cache currently // if has no reftable, then don't use cache currently
if session.Statement.RefTable == nil || if session.Statement.RefTable == nil ||
@ -715,7 +749,11 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
} }
func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) {
<<<<<<< HEAD
if session.Statement.RefTable == nil || if session.Statement.RefTable == nil ||
=======
if !session.canCache() ||
>>>>>>> refs/remotes/go-xorm/master
indexNoCase(sqlStr, "having") != -1 || indexNoCase(sqlStr, "having") != -1 ||
indexNoCase(sqlStr, "group by") != -1 { indexNoCase(sqlStr, "group by") != -1 {
return ErrCacheFailed return ErrCacheFailed
@ -1309,19 +1347,33 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} else { } else {
table = session.Statement.RefTable table = session.Statement.RefTable
} }
<<<<<<< HEAD
fmt.Println("sliceValue.Kind()") fmt.Println("sliceValue.Kind()")
if len(condiBean) > 0 { if len(condiBean) > 0 {
colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true,
false, true, session.Statement.allUseBool, session.Statement.useAllCols, false, true, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.unscoped, session.Statement.mustColumnMap) session.Statement.unscoped, session.Statement.mustColumnMap)
=======
var addedTableName = (len(session.Statement.JoinStr) > 0)
if len(condiBean) > 0 {
colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true,
false, true, session.Statement.allUseBool, session.Statement.useAllCols,
session.Statement.unscoped, session.Statement.mustColumnMap,
session.Statement.TableName(), addedTableName)
>>>>>>> refs/remotes/go-xorm/master
session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.ConditionStr = strings.Join(colNames, " AND ")
session.Statement.BeanArgs = args session.Statement.BeanArgs = args
} else { } else {
// !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given. // !oinume! Add "<col> IS NULL" to WHERE whatever condiBean is given.
// See https://github.com/go-xorm/xorm/issues/179 // See https://github.com/go-xorm/xorm/issues/179
if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled
var colName string = session.Engine.Quote(col.Name)
if addedTableName {
colName = session.Engine.Quote(session.Statement.TableName()) + "." + colName
}
session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00') ", session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00') ",
session.Engine.Quote(col.Name), session.Engine.Quote(col.Name)) colName, colName)
} }
} }
fmt.Println("sliceValue.Kind()") fmt.Println("sliceValue.Kind()")
@ -1437,7 +1489,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
table := session.Engine.autoMapType(dataStruct) table := session.Engine.autoMapType(dataStruct)
return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc) return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc)
} else { } else {
@ -1583,7 +1634,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64 var total int64
sql := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName)) sql := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName))
err := session.DB().QueryRow(sql).Scan(&total) err := session.DB().QueryRow(sql).Scan(&total)
session.Engine.logSQL(sql) session.saveLastSQL(sql)
if err != nil { if err != nil {
return true, err return true, err
} }
@ -1752,6 +1803,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
} }
} }
defer func() {
if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet {
for ii, key := range fields {
b.AfterSet(key, Cell(scanResults[ii].(*interface{})))
}
}
}()
var tempMap = make(map[string]int) var tempMap = make(map[string]int)
for ii, key := range fields { for ii, key := range fields {
var idx int var idx int
@ -1801,7 +1860,6 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
hasAssigned := false hasAssigned := false
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
if rawValueType.Kind() == reflect.String { if rawValueType.Kind() == reflect.String {
hasAssigned = true hasAssigned = true
@ -1812,6 +1870,15 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
return err return err
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
session.Engine.LogError(err)
return err
}
fieldValue.Set(x.Elem())
} }
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
switch rawValueType.Kind() { switch rawValueType.Kind() {
@ -1856,6 +1923,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
fieldValue.SetUint(uint64(vv.Int())) fieldValue.SetUint(uint64(vv.Int()))
} }
case reflect.Struct: case reflect.Struct:
col := table.GetColumn(key)
if fieldType.ConvertibleTo(core.TimeType) { if fieldType.ConvertibleTo(core.TimeType) {
if rawValueType == core.TimeType { if rawValueType == core.TimeType {
hasAssigned = true hasAssigned = true
@ -1863,7 +1931,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
t := vv.Convert(core.TimeType).Interface().(time.Time) t := vv.Convert(core.TimeType).Interface().(time.Time)
z, _ := t.Zone() z, _ := t.Zone()
if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location
session.Engine.LogDebug("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) session.Engine.LogDebugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location())
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(),
t.Minute(), t.Second(), t.Nanosecond(), time.Local) t.Minute(), t.Second(), t.Nanosecond(), time.Local)
} }
@ -1881,13 +1949,42 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount
vv = reflect.ValueOf(t) vv = reflect.ValueOf(t)
fieldValue.Set(vv) fieldValue.Set(vv)
} }
} else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
// !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
hasAssigned = true
if err := nulVal.Scan(vv.Interface()); err != nil {
//fmt.Println("sql.Sanner error:", err.Error())
session.Engine.LogError("sql.Sanner error:", err.Error())
hasAssigned = false
}
} else if col.SQLType.IsJson() {
if rawValueType.Kind() == reflect.String {
hasAssigned = true
x := reflect.New(fieldType)
err := json.Unmarshal([]byte(vv.String()), x.Interface())
if err != nil {
session.Engine.LogError(err)
return err
}
fieldValue.Set(x.Elem())
} else if rawValueType.Kind() == reflect.Slice {
hasAssigned = true
x := reflect.New(fieldType)
err := json.Unmarshal(vv.Bytes(), x.Interface())
if err != nil {
session.Engine.LogError(err)
return err
}
fieldValue.Set(x.Elem())
}
} else if session.Statement.UseCascade { } else if session.Statement.UseCascade {
table := session.Engine.autoMapType(*fieldValue) table := session.Engine.autoMapType(*fieldValue)
if table != nil { if table != nil {
if len(table.PrimaryKeys) > 1 { if len(table.PrimaryKeys) != 1 {
panic("unsupported composited primary key cascade") panic("unsupported non or composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) var pk = make(core.PK, len(table.PrimaryKeys))
switch rawValueType.Kind() { switch rawValueType.Kind() {
case reflect.Int64: case reflect.Int64:
pk[0] = vv.Int() pk[0] = vv.Int()
@ -2075,7 +2172,7 @@ func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{})
*sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable) *sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
session.Engine.logSQL(*sqlStr, paramStr...) session.saveLastSQL(*sqlStr, paramStr...)
} }
func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
@ -2247,7 +2344,9 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
cols := make([]*core.Column, 0) cols := make([]*core.Column, 0)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
elemValue := sliceValue.Index(i).Interface() v := sliceValue.Index(i)
vv := reflect.Indirect(v)
elemValue := v.Interface()
colPlaces := make([]string, 0) colPlaces := make([]string, 0)
// handle BeforeInsertProcessor // handle BeforeInsertProcessor
@ -2263,7 +2362,11 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if i == 0 { if i == 0 {
for _, col := range table.Columns() { for _, col := range table.Columns() {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) ptrFieldValue, err := col.ValueOfV(&vv)
if err != nil {
return 0, err
}
fieldValue := *ptrFieldValue
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
@ -2306,7 +2409,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
} }
} else { } else {
for _, col := range cols { for _, col := range cols {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) ptrFieldValue, err := col.ValueOfV(&vv)
if err != nil {
return 0, err
}
fieldValue := *ptrFieldValue
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
@ -2370,7 +2478,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
lenAfterClosures := len(session.afterClosures) lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
elemValue := sliceValue.Index(i).Interface() elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
// handle AfterInsertProcessor // handle AfterInsertProcessor
if session.IsAutoCommit { if session.IsAutoCommit {
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
@ -2585,108 +2694,115 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value,
fieldValue.SetUint(x) fieldValue.SetUint(x)
//Currently only support Time type //Currently only support Time type
case reflect.Struct: case reflect.Struct:
if fieldType.ConvertibleTo(core.TimeType) { // !<winxxp>! 增加支持sql.Scanner接口的结构如sql.NullString
x, err := session.byte2Time(col, data) if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
if err != nil { if err := nulVal.Scan(data); err != nil {
return err return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error())
} }
v = x } else {
fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) if fieldType.ConvertibleTo(core.TimeType) {
} else if session.Statement.UseCascade { x, err := session.byte2Time(col, data)
table := session.Engine.autoMapType(*fieldValue) if err != nil {
if table != nil { return err
if len(table.PrimaryKeys) > 1 {
panic("unsupported composited primary key cascade")
} }
var pk = make(core.PK, len(table.PrimaryKeys)) v = x
rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) fieldValue.Set(reflect.ValueOf(v).Convert(fieldType))
switch rawValueType.Kind() { } else if session.Statement.UseCascade {
case reflect.Int64: table := session.Engine.autoMapType(*fieldValue)
x, err := strconv.ParseInt(string(data), 10, 64) if table != nil {
if err != nil { if len(table.PrimaryKeys) > 1 {
return fmt.Errorf("arg %v as int: %s", key, err.Error()) panic("unsupported composited primary key cascade")
} }
pk[0] = x var pk = make(core.PK, len(table.PrimaryKeys))
case reflect.Int: rawValueType := table.ColumnType(table.PKColumns()[0].FieldName)
x, err := strconv.ParseInt(string(data), 10, 64) switch rawValueType.Kind() {
if err != nil { case reflect.Int64:
return fmt.Errorf("arg %v as int: %s", key, err.Error()) x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = x
case reflect.Int:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int(x)
case reflect.Int32:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int32(x)
case reflect.Int16:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int16(x)
case reflect.Int8:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int8(x)
case reflect.Uint64:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = x
case reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint(x)
case reflect.Uint32:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint32(x)
case reflect.Uint16:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint16(x)
case reflect.Uint8:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint8(x)
case reflect.String:
pk[0] = string(data)
default:
panic("unsupported primary key type cascade")
} }
pk[0] = int(x)
case reflect.Int32:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int32(x)
case reflect.Int16:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int16(x)
case reflect.Int8:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = int8(x)
case reflect.Uint64:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = x
case reflect.Uint:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint(x)
case reflect.Uint32:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint32(x)
case reflect.Uint16:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint16(x)
case reflect.Uint8:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return fmt.Errorf("arg %v as int: %s", key, err.Error())
}
pk[0] = uint8(x)
case reflect.String:
pk[0] = string(data)
default:
panic("unsupported primary key type cascade")
}
if !isPKZero(pk) { if !isPKZero(pk) {
// !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch
// however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne
// property to be fetched lazily // property to be fetched lazily
structInter := reflect.New(fieldValue.Type()) structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession() newsession := session.Engine.NewSession()
defer newsession.Close() defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return err return err
} }
if has { if has {
v = structInter.Elem().Interface() v = structInter.Elem().Interface()
fieldValue.Set(reflect.ValueOf(v)) fieldValue.Set(reflect.ValueOf(v))
} else { } else {
return errors.New("cascade obj is not exist!") return errors.New("cascade obj is not exist!")
}
} }
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
} }
} else {
return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String())
} }
} }
case reflect.Ptr: case reflect.Ptr:
@ -3110,16 +3226,45 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return fieldValue.Interface(), nil return fieldValue.Interface(), nil
} }
} }
if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
if !col.SQLType.IsJson() {
// !<winxxp>! 增加支持driver.Valuer接口的结构如sql.NullString
if v, ok := fieldValue.Interface().(driver.Valuer); ok {
return v.Value()
}
fieldTable := session.Engine.autoMapType(fieldValue)
//if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok {
if len(fieldTable.PrimaryKeys) == 1 { if len(fieldTable.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName)
return pkField.Interface(), nil return pkField.Interface(), nil
} else {
return 0, fmt.Errorf("no primary key for col %v", col.Name)
} }
<<<<<<< HEAD
} else { } else {
return 0, fmt.Errorf("Unsupported type %v\n", fieldValue.Type()) return 0, fmt.Errorf("Unsupported type %v\n", fieldValue.Type())
=======
return 0, fmt.Errorf("no primary key for col %v", col.Name)
//}
>>>>>>> refs/remotes/go-xorm/master
} }
if col.SQLType.IsText() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogError(err)
return 0, err
}
return string(bytes), nil
} else if col.SQLType.IsBlob() {
bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
session.Engine.LogError(err)
return 0, err
}
return bytes, nil
}
return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type())
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
bytes, err := json.Marshal(fieldValue.Interface()) bytes, err := json.Marshal(fieldValue.Interface())
if err != nil { if err != nil {
@ -3153,9 +3298,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
} }
} }
return bytes, nil return bytes, nil
} else {
return nil, ErrUnSupportedType
} }
return nil, ErrUnSupportedType
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return int64(fieldValue.Uint()), nil return int64(fieldValue.Uint()), nil
default: default:
@ -3177,12 +3321,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
processor.BeforeInsert() processor.BeforeInsert()
} }
// -- // --
colNames, args, err := genCols(table, session, bean, false, false) colNames, args, err := genCols(table, session, bean, false, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
// insert expr columns, override if exists // insert expr columns, override if exists
exprColumns := session.Statement.getExpr() exprColumns := session.Statement.getExpr()
exprColVals := make([]string, 0, len(exprColumns)) exprColVals := make([]string, 0, len(exprColumns))
@ -3247,9 +3389,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != ""
res, err := session.query("select seq_atable.currval from dual", args...)
if session.Engine.DriverName() != core.POSTGRES || table.AutoIncrement == "" {
res, err := session.exec(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} else { } else {
@ -3269,14 +3412,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
} }
if table.AutoIncrement == "" { if len(res) < 1 {
return res.RowsAffected() return 0, errors.New("insert no error but not returned id")
} }
var id int64 = 0 idByte := res[0][table.AutoIncrement]
id, err = res.LastInsertId() id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil || id <= 0 { if err != nil {
return res.RowsAffected() return 1, err
} }
aiValue, err := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
@ -3284,8 +3427,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
session.Engine.LogError(err) session.Engine.LogError(err)
} }
if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { if aiValue == nil || !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() {
return res.RowsAffected() return 1, nil
} }
var v interface{} = id var v interface{} = id
@ -3303,8 +3446,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
aiValue.Set(reflect.ValueOf(v)) aiValue.Set(reflect.ValueOf(v))
return res.RowsAffected() return 1, nil
} else { } else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 {
//assert table.AutoIncrement != "" //assert table.AutoIncrement != ""
sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement)
res, err := session.query(sqlStr, args...) res, err := session.query(sqlStr, args...)
@ -3363,6 +3506,66 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
aiValue.Set(reflect.ValueOf(v)) aiValue.Set(reflect.ValueOf(v))
return 1, nil return 1, nil
} else {
res, err := session.exec(sqlStr, args...)
if err != nil {
return 0, err
} else {
handleAfterInsertProcessorFunc(bean)
}
if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName())
}
if table.Version != "" && session.Statement.checkVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.Engine.LogError(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1)
}
}
if table.AutoIncrement == "" {
return res.RowsAffected()
}
var id int64 = 0
id, err = res.LastInsertId()
if err != nil || id <= 0 {
return res.RowsAffected()
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.Engine.LogError(err)
}
if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() {
return res.RowsAffected()
}
var v interface{} = id
switch aiValue.Type().Kind() {
case reflect.Int16:
v = int16(id)
case reflect.Int32:
v = int32(id)
case reflect.Int:
v = int(id)
case reflect.Uint16:
v = uint16(id)
case reflect.Uint32:
v = uint32(id)
case reflect.Uint64:
v = uint64(id)
case reflect.Uint:
v = uint(id)
}
aiValue.Set(reflect.ValueOf(v))
return res.RowsAffected()
} }
} }
@ -3383,7 +3586,7 @@ func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) {
return "", "" return "", ""
} }
colstrs := statement.JoinColumns(statement.RefTable.PKColumns()) colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true)
sqls := splitNNoCase(sqlStr, "where", 2) sqls := splitNNoCase(sqlStr, "where", 2)
if len(sqls) != 2 { if len(sqls) != 2 {
if len(sqls) == 1 { if len(sqls) == 1 {
@ -3592,7 +3795,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if session.Statement.ColumnStr == "" { if session.Statement.ColumnStr == "" {
colNames, args = buildUpdates(session.Engine, table, bean, false, false, colNames, args = buildUpdates(session.Engine, table, bean, false, false,
false, false, session.Statement.allUseBool, session.Statement.useAllCols, false, false, session.Statement.allUseBool, session.Statement.useAllCols,
<<<<<<< HEAD
session.Statement.mustColumnMap, session.Statement.columnMap, true) session.Statement.mustColumnMap, session.Statement.columnMap, true)
=======
session.Statement.mustColumnMap, session.Statement.nullableMap,
session.Statement.columnMap, true, session.Statement.unscoped)
>>>>>>> refs/remotes/go-xorm/master
} else { } else {
colNames, args, err = genCols(table, session, bean, true, true) colNames, args, err = genCols(table, session, bean, true, true)
if err != nil { if err != nil {
@ -3862,7 +4070,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
session.Statement.RefTable = table session.Statement.RefTable = table
colNames, args := buildConditions(session.Engine, table, bean, true, true, colNames, args := buildConditions(session.Engine, table, bean, true, true,
false, true, session.Statement.allUseBool, session.Statement.useAllCols, false, true, session.Statement.allUseBool, session.Statement.useAllCols,
<<<<<<< HEAD
session.Statement.unscoped, session.Statement.mustColumnMap) session.Statement.unscoped, session.Statement.mustColumnMap)
=======
session.Statement.unscoped, session.Statement.mustColumnMap,
session.Statement.TableName(), false)
>>>>>>> refs/remotes/go-xorm/master
var condition = "" var condition = ""
var andStr = session.Engine.dialect.AndStr() var andStr = session.Engine.dialect.AndStr()
@ -3966,6 +4179,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
return res.RowsAffected() return res.RowsAffected()
} }
// saveLastSQL stores executed query information
func (session *Session) saveLastSQL(sql string, args ...interface{}) {
session.lastSQL = sql
session.lastSQLArgs = args
session.Engine.logSQL(sql, args...)
}
// LastSQL returns last query information
func (session *Session) LastSQL() (string, []interface{}) {
return session.lastSQL, session.lastSQLArgs
}
func (s *Session) Sync2(beans ...interface{}) error { func (s *Session) Sync2(beans ...interface{}) error {
engine := s.Engine engine := s.Engine
@ -4054,6 +4279,7 @@ func (s *Session) Sync2(beans ...interface{}) error {
} }
var foundIndexNames = make(map[string]bool) var foundIndexNames = make(map[string]bool)
var addedNames = make(map[string]*core.Index)
for name, index := range table.Indexes { for name, index := range table.Indexes {
var oriIndex *core.Index var oriIndex *core.Index
@ -4077,20 +4303,7 @@ func (s *Session) Sync2(beans ...interface{}) error {
} }
if oriIndex == nil { if oriIndex == nil {
if index.Type == core.UniqueType { addedNames[name] = index
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addUnique(table.Name, name)
} else if index.Type == core.IndexType {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addIndex(table.Name, name)
}
if err != nil {
return err
}
} }
} }
@ -4103,6 +4316,23 @@ func (s *Session) Sync2(beans ...interface{}) error {
} }
} }
} }
for name, index := range addedNames {
if index.Type == core.UniqueType {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addUnique(table.Name, name)
} else if index.Type == core.IndexType {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addIndex(table.Name, name)
}
if err != nil {
return err
}
}
} }
} }

View File

@ -156,6 +156,13 @@ func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName st
func (db *sqlite3) SqlType(c *core.Column) string { func (db *sqlite3) SqlType(c *core.Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case core.Bool:
if c.Default == "true" {
c.Default = "1"
} else if c.Default == "false" {
c.Default = "0"
}
return core.Integer
case core.Date, core.DateTime, core.TimeStamp, core.Time: case core.Date, core.DateTime, core.TimeStamp, core.Time:
return core.DateTime return core.DateTime
case core.TimeStampz: case core.TimeStampz:
@ -163,7 +170,7 @@ func (db *sqlite3) SqlType(c *core.Column) string {
case core.Char, core.Varchar, core.NVarchar, core.TinyText, case core.Char, core.Varchar, core.NVarchar, core.TinyText,
core.Text, core.MediumText, core.LongText, core.Json: core.Text, core.MediumText, core.LongText, core.Json:
return core.Text return core.Text
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool: case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt:
return core.Integer return core.Integer
case core.Float, core.Double, core.Real: case core.Float, core.Double, core.Real:
return core.Real return core.Real
@ -243,6 +250,10 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string {
return fmt.Sprintf("DROP INDEX %v", quote(idxName)) return fmt.Sprintf("DROP INDEX %v", quote(idxName))
} }
func (db *sqlite3) ForUpdateSql(query string) string {
return query
}
/*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { /*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"

View File

@ -5,6 +5,8 @@
package xorm package xorm
import ( import (
"bytes"
"database/sql/driver"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -65,6 +67,7 @@ type Statement struct {
UseCache bool UseCache bool
UseAutoTime bool UseAutoTime bool
IsDistinct bool IsDistinct bool
IsForUpdate bool
TableAlias string TableAlias string
allUseBool bool allUseBool bool
checkVersion bool checkVersion bool
@ -101,6 +104,7 @@ func (statement *Statement) Init() {
statement.UseCache = true statement.UseCache = true
statement.UseAutoTime = true statement.UseAutoTime = true
statement.IsDistinct = false statement.IsDistinct = false
statement.IsForUpdate = false
statement.TableAlias = "" statement.TableAlias = ""
statement.selectStr = "" statement.selectStr = ""
statement.allUseBool = false statement.allUseBool = false
@ -140,9 +144,11 @@ func (statement *Statement) Where(querystring string, args ...interface{}) *Stat
// add Where & and statment // add Where & and statment
func (statement *Statement) And(querystring string, args ...interface{}) *Statement { func (statement *Statement) And(querystring string, args ...interface{}) *Statement {
if statement.WhereStr != "" { if len(statement.WhereStr) > 0 {
statement.WhereStr = fmt.Sprintf("(%v) %s (%v)", statement.WhereStr, var buf bytes.Buffer
fmt.Fprintf(&buf, "(%v) %s (%v)", statement.WhereStr,
statement.Engine.dialect.AndStr(), querystring) statement.Engine.dialect.AndStr(), querystring)
statement.WhereStr = buf.String()
} else { } else {
statement.WhereStr = querystring statement.WhereStr = querystring
} }
@ -152,9 +158,11 @@ func (statement *Statement) And(querystring string, args ...interface{}) *Statem
// add Where & Or statment // add Where & Or statment
func (statement *Statement) Or(querystring string, args ...interface{}) *Statement { func (statement *Statement) Or(querystring string, args ...interface{}) *Statement {
if statement.WhereStr != "" { if len(statement.WhereStr) > 0 {
statement.WhereStr = fmt.Sprintf("(%v) %s (%v)", statement.WhereStr, var buf bytes.Buffer
fmt.Fprintf(&buf, "(%v) %s (%v)", statement.WhereStr,
statement.Engine.dialect.OrStr(), querystring) statement.Engine.dialect.OrStr(), querystring)
statement.WhereStr = buf.String()
} else { } else {
statement.WhereStr = querystring statement.WhereStr = querystring
} }
@ -179,7 +187,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, includeAutoIncr bool, allUseBool bool, useAllCols bool,
mustColumnMap map[string]bool, nullableMap map[string]bool, mustColumnMap map[string]bool, nullableMap map[string]bool,
columnMap map[string]bool, update bool) ([]string, []interface{}) { columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) {
colNames := make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
@ -196,7 +204,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
if !includeAutoIncr && col.IsAutoIncrement { if !includeAutoIncr && col.IsAutoIncrement {
continue continue
} }
if col.IsDeleted { if col.IsDeleted && !unscoped {
continue continue
} }
if use, ok := columnMap[col.Name]; ok && !use { if use, ok := columnMap[col.Name]; ok && !use {
@ -219,6 +227,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
requiredField := useAllCols requiredField := useAllCols
includeNil := useAllCols includeNil := useAllCols
lColName := strings.ToLower(col.Name) lColName := strings.ToLower(col.Name)
if b, ok := mustColumnMap[lColName]; ok { if b, ok := mustColumnMap[lColName]; ok {
if b { if b {
requiredField = true requiredField = true
@ -320,24 +329,38 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{},
continue continue
} }
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
} else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = nulType.Value()
} else { } else {
engine.autoMapType(fieldValue) if !col.SQLType.IsJson() {
if table, ok := engine.Tables[fieldValue.Type()]; ok { engine.autoMapType(fieldValue)
if len(table.PrimaryKeys) == 1 { if table, ok := engine.Tables[fieldValue.Type()]; ok {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) if len(table.PrimaryKeys) == 1 {
// fix non-int pk issues pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
//if pkField.Int() != 0 { // fix non-int pk issues
if pkField.IsValid() && !isZero(pkField.Interface()) { //if pkField.Int() != 0 {
val = pkField.Interface() if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else { } else {
continue //TODO: how to handler?
panic("not supported")
} }
} else { } else {
//TODO: how to handler? val = fieldValue.Interface()
panic("not supported")
} }
} else { } else {
val = fieldValue.Interface() bytes, err := json.Marshal(fieldValue.Interface())
if err != nil {
panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
}
if col.SQLType.IsText() {
val = string(bytes)
} else if col.SQLType.IsBlob() {
val = bytes
}
} }
} }
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice, reflect.Map:
@ -413,6 +436,9 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
continue continue
} }
if col.SQLType.IsJson() {
continue
}
var colName string var colName string
if addedTableName { if addedTableName {
@ -509,24 +535,49 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{},
val = engine.FormatTime(col.SQLType.Name, t) val = engine.FormatTime(col.SQLType.Name, t)
} else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok {
continue continue
} else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok {
val, _ = valNul.Value()
if val == nil {
continue
}
} else { } else {
engine.autoMapType(fieldValue) if col.SQLType.IsJson() {
if table, ok := engine.Tables[fieldValue.Type()]; ok { if col.SQLType.IsText() {
if len(table.PrimaryKeys) == 1 { bytes, err := json.Marshal(fieldValue.Interface())
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) if err != nil {
// fix non-int pk issues engine.LogError(err)
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue continue
} }
} else { val = string(bytes)
//TODO: how to handler? } else if col.SQLType.IsBlob() {
panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) var bytes []byte
var err error
bytes, err = json.Marshal(fieldValue.Interface())
if err != nil {
engine.LogError(err)
continue
}
val = bytes
} }
} else { } else {
val = fieldValue.Interface() engine.autoMapType(fieldValue)
if table, ok := engine.Tables[fieldValue.Type()]; ok {
if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues
//if pkField.Int() != 0 {
if pkField.IsValid() && !isZero(pkField.Interface()) {
val = pkField.Interface()
} else {
continue
}
} else {
//TODO: how to handler?
panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys))
}
} else {
val = fieldValue.Interface()
}
} }
} }
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice, reflect.Map:
@ -693,12 +744,17 @@ func (statement *Statement) genInSql() (string, []interface{}) {
return "", []interface{}{} return "", []interface{}{}
} }
inStrs := make([]string, 0, len(statement.inColumns)) inStrs := make([]string, len(statement.inColumns), len(statement.inColumns))
args := make([]interface{}, 0) args := make([]interface{}, 0)
var buf bytes.Buffer
var i int
for _, params := range statement.inColumns { for _, params := range statement.inColumns {
inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", buf.Reset()
fmt.Fprintf(&buf, "(%v IN (%v))",
statement.Engine.autoQuote(params.colName), statement.Engine.autoQuote(params.colName),
strings.Join(makeArray("?", len(params.args)), ","))) strings.Join(makeArray("?", len(params.args)), ","))
inStrs[i] = buf.String()
i++
args = append(args, params.args...) args = append(args, params.args...)
} }
@ -711,7 +767,7 @@ func (statement *Statement) genInSql() (string, []interface{}) {
func (statement *Statement) attachInSql() { func (statement *Statement) attachInSql() {
inSql, inArgs := statement.genInSql() inSql, inArgs := statement.genInSql()
if len(inSql) > 0 { if len(inSql) > 0 {
if statement.ConditionStr != "" { if len(statement.ConditionStr) > 0 {
statement.ConditionStr += " " + statement.Engine.dialect.AndStr() + " " statement.ConditionStr += " " + statement.Engine.dialect.AndStr() + " "
} }
statement.ConditionStr += inSql statement.ConditionStr += inSql
@ -770,6 +826,12 @@ func (statement *Statement) Distinct(columns ...string) *Statement {
return statement return statement
} }
// Generate "SELECT ... FOR UPDATE" statment
func (statement *Statement) ForUpdate() *Statement {
statement.IsForUpdate = true
return statement
}
// replace select // replace select
func (s *Statement) Select(str string) *Statement { func (s *Statement) Select(str string) *Statement {
s.selectStr = str s.selectStr = str
@ -794,6 +856,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
if strings.Contains(statement.ColumnStr, ".") { if strings.Contains(statement.ColumnStr, ".") {
statement.ColumnStr = strings.Replace(statement.ColumnStr, ".", statement.Engine.Quote("."), -1) statement.ColumnStr = strings.Replace(statement.ColumnStr, ".", statement.Engine.Quote("."), -1)
} }
statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.Quote("*"), "*", -1)
return statement return statement
} }
@ -812,15 +875,6 @@ func (statement *Statement) MustCols(columns ...string) *Statement {
return statement return statement
} }
// Update use only: not update columns
/*func (statement *Statement) NotCols(columns ...string) *Statement {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.mustColumnMap[strings.ToLower(nc)] = false
}
return statement
}*/
// indicates that use bool fields as update contents and query contiditions // indicates that use bool fields as update contents and query contiditions
func (statement *Statement) UseBool(columns ...string) *Statement { func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 { if len(columns) > 0 {
@ -865,7 +919,7 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
// Generate "Order By order" statement // Generate "Order By order" statement
func (statement *Statement) OrderBy(order string) *Statement { func (statement *Statement) OrderBy(order string) *Statement {
if statement.OrderStr != "" { if len(statement.OrderStr) > 0 {
statement.OrderStr += ", " statement.OrderStr += ", "
} }
statement.OrderStr += order statement.OrderStr += order
@ -873,44 +927,51 @@ func (statement *Statement) OrderBy(order string) *Statement {
} }
func (statement *Statement) Desc(colNames ...string) *Statement { func (statement *Statement) Desc(colNames ...string) *Statement {
if statement.OrderStr != "" { var buf bytes.Buffer
statement.OrderStr += ", " fmt.Fprintf(&buf, statement.OrderStr)
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, ", ")
} }
newColNames := statement.col2NewColsWithQuote(colNames...) newColNames := statement.col2NewColsWithQuote(colNames...)
sqlStr := strings.Join(newColNames, " DESC, ") fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
statement.OrderStr += sqlStr + " DESC" statement.OrderStr = buf.String()
return statement return statement
} }
// Method Asc provide asc order by query condition, the input parameters are columns. // Method Asc provide asc order by query condition, the input parameters are columns.
func (statement *Statement) Asc(colNames ...string) *Statement { func (statement *Statement) Asc(colNames ...string) *Statement {
if statement.OrderStr != "" { var buf bytes.Buffer
statement.OrderStr += ", " fmt.Fprintf(&buf, statement.OrderStr)
if len(statement.OrderStr) > 0 {
fmt.Fprint(&buf, ", ")
} }
newColNames := statement.col2NewColsWithQuote(colNames...) newColNames := statement.col2NewColsWithQuote(colNames...)
sqlStr := strings.Join(newColNames, " ASC, ") fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
statement.OrderStr += sqlStr + " ASC" statement.OrderStr = buf.String()
return statement return statement
} }
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(join_operator string, tablename interface{}, condition string) *Statement { func (statement *Statement) Join(join_operator string, tablename interface{}, condition string) *Statement {
var joinTable string var buf bytes.Buffer
if len(statement.JoinStr) > 0 {
fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, join_operator)
} else {
fmt.Fprintf(&buf, "%v JOIN ", join_operator)
}
switch tablename.(type) { switch tablename.(type) {
case []string: case []string:
t := tablename.([]string) t := tablename.([]string)
l := len(t) if len(t) > 1 {
if l > 1 { fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
table := t[0] } else if len(t) == 1 {
joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(t[1]) fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
} else if l == 1 {
table := t[0]
joinTable = statement.Engine.Quote(table)
} }
case []interface{}: case []interface{}:
t := tablename.([]interface{}) t := tablename.([]interface{})
l := len(t) l := len(t)
table := "" var table string
if l > 0 { if l > 0 {
f := t[0] f := t[0]
v := rValue(f) v := rValue(f)
@ -923,21 +984,17 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co
} }
} }
if l > 1 { if l > 1 {
joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(fmt.Sprintf("%v", t[1])) fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 { } else if l == 1 {
joinTable = statement.Engine.Quote(table) fmt.Fprintf(&buf, statement.Engine.Quote(table))
} }
default: default:
t := fmt.Sprintf("%v", tablename) fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
joinTable = statement.Engine.Quote(t)
}
if statement.JoinStr != "" {
statement.JoinStr = statement.JoinStr + fmt.Sprintf(" %v JOIN %v ON %v", join_operator,
joinTable, condition)
} else {
statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator,
joinTable, condition)
} }
fmt.Fprintf(&buf, " ON %v", condition)
statement.JoinStr = buf.String()
return statement return statement
} }
@ -1054,11 +1111,6 @@ func (s *Statement) genDelIndexSQL() []string {
return sqls return sqls
} }
/*
func (s *Statement) genDropSQL() string {
return s.Engine.dialect.MustDropTa(s.TableName()) + ";"
}*/
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
var table *core.Table var table *core.Table
if statement.RefTable == nil { if statement.RefTable == nil {
@ -1148,37 +1200,35 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}
} }
func (statement *Statement) genSelectSql(columnStr string) (a string) { func (statement *Statement) genSelectSql(columnStr string) (a string) {
/*if statement.GroupByStr != "" {
if columnStr == "" {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
}
//statement.GroupByStr = columnStr
}*/
var distinct string var distinct string
if statement.IsDistinct { if statement.IsDistinct {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
var dialect core.Dialect = statement.Engine.Dialect()
var top string var top string
var mssqlCondi string var mssqlCondi string
/*var orderBy string
if statement.OrderStr != "" {
orderBy = fmt.Sprintf(" ORDER BY %v", statement.OrderStr)
}*/
statement.processIdParam() statement.processIdParam()
var whereStr string
if statement.WhereStr != "" { var buf bytes.Buffer
whereStr = fmt.Sprintf(" WHERE %v", statement.WhereStr) if len(statement.WhereStr) > 0 {
if statement.ConditionStr != "" { if len(statement.ConditionStr) > 0 {
whereStr = fmt.Sprintf("%v %s %v", whereStr, statement.Engine.Dialect().AndStr(), fmt.Fprintf(&buf, " WHERE (%v)", statement.WhereStr)
statement.ConditionStr) } else {
fmt.Fprintf(&buf, " WHERE %v", statement.WhereStr)
} }
} else if statement.ConditionStr != "" { if statement.ConditionStr != "" {
whereStr = fmt.Sprintf(" WHERE %v", statement.ConditionStr) fmt.Fprintf(&buf, " %s (%v)", dialect.AndStr(), statement.ConditionStr)
}
} else if len(statement.ConditionStr) > 0 {
fmt.Fprintf(&buf, " WHERE %v", statement.ConditionStr)
} }
var whereStr = buf.String()
var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName())
if statement.TableAlias != "" { if statement.TableAlias != "" {
if statement.Engine.dialect.DBType() == core.ORACLE { if dialect.DBType() == core.ORACLE {
fromStr += " " + statement.Engine.Quote(statement.TableAlias) fromStr += " " + statement.Engine.Quote(statement.TableAlias)
} else { } else {
fromStr += " AS " + statement.Engine.Quote(statement.TableAlias) fromStr += " AS " + statement.Engine.Quote(statement.TableAlias)
@ -1188,7 +1238,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
} }
if statement.Engine.dialect.DBType() == core.MSSQL { if dialect.DBType() == core.MSSQL {
if statement.LimitN > 0 { if statement.LimitN > 0 {
top = fmt.Sprintf(" TOP %d ", statement.LimitN) top = fmt.Sprintf(" TOP %d ", statement.LimitN)
} }
@ -1219,10 +1269,9 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
} }
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr)
fromStr, whereStr) if len(mssqlCondi) > 0 {
if mssqlCondi != "" { if len(whereStr) > 0 {
if whereStr != "" {
a += " AND " + mssqlCondi a += " AND " + mssqlCondi
} else { } else {
a += " WHERE " + mssqlCondi a += " WHERE " + mssqlCondi
@ -1238,17 +1287,20 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" { if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE { if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 { } else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
} }
} else if statement.Engine.dialect.DBType() == core.ORACLE { } else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 { if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
} }
} }
if statement.IsForUpdate {
a = dialect.ForUpdateSql(a)
}
return return
} }

View File

@ -17,7 +17,7 @@ import (
) )
const ( const (
Version string = "0.4.3.0526" Version string = "0.4.4.1029"
) )
func regDrvsNDialects() bool { func regDrvsNDialects() bool {
@ -39,7 +39,7 @@ func regDrvsNDialects() bool {
for driverName, v := range providedDrvsNDialects { for driverName, v := range providedDrvsNDialects {
if driver := core.QueryDriver(driverName); driver == nil { if driver := core.QueryDriver(driverName); driver == nil {
core.RegisterDriver(driverName, v.getDriver()) core.RegisterDriver(driverName, v.getDriver())
core.RegisterDialect(v.dbType, v.getDialect()) core.RegisterDialect(v.dbType, v.getDialect)
} }
} }
return true return true