Merge branch 'master' into master

This commit is contained in:
Lunny Xiao 2018-03-07 17:19:48 +08:00 committed by GitHub
commit 99cc3e5b65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 460 additions and 72 deletions

2
.gitignore vendored
View File

@ -28,3 +28,5 @@ temp_test.go
.vscode
xorm.test
*.sqlite3
.idea/

View File

@ -217,7 +217,7 @@ has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
```Go
var users []User
err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users)
// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10
// SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0
type Detail struct {
Id int64
@ -234,7 +234,7 @@ err := engine.Table("user").Select("user.*, detail.*").
Join("INNER", "detail", "detail.user_id = user.id").
Where("user.name = ?", name).Limit(10, 0).
Find(&users)
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0
```
* `Iterate` and `Rows` query multiple records and record by record handle, there are two methods Iterate and Rows
@ -265,7 +265,7 @@ for rows.Next() {
* `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on.
```Go
affected, err := engine.Id(1).Update(&user)
affected, err := engine.ID(1).Update(&user)
// UPDATE user SET ... Where id = ?
affected, err := engine.Update(&user, &User{Name:name})
@ -276,14 +276,14 @@ affected, err := engine.In("id", ids).Update(&user)
// UPDATE user SET ... Where id IN (?, ?, ?)
// force update indicated columns by Cols
affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
// force NOT update indicated columns by Omit
affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
affected, err := engine.Id(1).AllCols().Update(&user)
affected, err := engine.ID(1).AllCols().Update(&user)
// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ?
```

View File

@ -224,7 +224,7 @@ has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist()
```Go
var users []User
err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users)
// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10
// SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0
type Detail struct {
Id int64
@ -241,7 +241,7 @@ err := engine.Table("user").Select("user.*, detail.*")
Join("INNER", "detail", "detail.user_id = user.id").
Where("user.name = ?", name).Limit(10, 0).
Find(&users)
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10
// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0
```
* `Iterate``Rows` 根据条件遍历数据库,可以有两种方式: Iterate and Rows
@ -272,7 +272,7 @@ for rows.Next() {
* `Update` 更新数据除非使用Cols,AllCols函数指明默认只更新非空和非0的字段
```Go
affected, err := engine.Id(1).Update(&user)
affected, err := engine.ID(1).Update(&user)
// UPDATE user SET ... Where id = ?
affected, err := engine.Update(&user, &User{Name:name})
@ -283,14 +283,14 @@ affected, err := engine.In(ids).Update(&user)
// UPDATE user SET ... Where id IN (?, ?, ?)
// force update indicated columns by Cols
affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
// force NOT update indicated columns by Omit
affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12})
affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12})
// UPDATE user SET age = ?, updated=? Where id = ?
affected, err := engine.Id(1).AllCols().Update(&user)
affected, err := engine.ID(1).AllCols().Update(&user)
// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ?
```

View File

@ -764,6 +764,9 @@ var (
"YES": true,
"ZONE": true,
}
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
)
type postgres struct {
@ -771,7 +774,14 @@ type postgres struct {
}
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(d, db, uri, drivername, dataSourceName)
err := db.Base.Init(d, db, uri, drivername, dataSourceName)
if err != nil {
return err
}
if db.Schema == "" {
db.Schema = DefaultPostgresSchema
}
return nil
}
func (db *postgres) SqlType(c *core.Column) string {
@ -868,29 +878,35 @@ func (db *postgres) IndexOnTable() bool {
}
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName}
if len(db.Schema) == 0 {
args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args
}
args := []interface{}{db.Schema, tableName, idxName}
return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args
}
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
if len(db.Schema) == 0 {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}
args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args
}*/
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
if len(db.Schema) == 0 {
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
tableName, col.Name, db.SqlType(col))
}
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s",
db.Schema, tableName, col.Name, db.SqlType(col))
}
func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
//var unique string
quote := db.Quote
idxName := index.Name
@ -906,9 +922,14 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
}
func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
args := []interface{}{tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
args := []interface{}{db.Schema, tableName, colName}
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" +
" AND column_name = $3"
if len(db.Schema) == 0 {
args = []interface{}{tableName, colName}
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" +
" AND column_name = $2"
}
db.LogSQL(query, args)
rows, err := db.DB().Query(query, args...)
@ -921,8 +942,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
}
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
// FIXME: the schema should be replaced by user custom's
args := []interface{}{tableName, "public"}
args := []interface{}{tableName}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -933,7 +953,15 @@ FROM pg_attribute f
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name
WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;`
WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;`
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
f = "AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@ -1023,9 +1051,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att
}
func (db *postgres) GetTables() ([]*core.Table, error) {
// FIXME: replace public to user customrize schema
args := []interface{}{"public"}
s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1")
args := []interface{}{}
s := "SELECT tablename FROM pg_tables"
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " WHERE schemaname = $1"
}
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
@ -1049,10 +1081,13 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
}
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
// FIXME: replace the public schema to user specify schema
args := []interface{}{"public", tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2")
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
db.LogSQL(s, args)
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " AND schemaname=$2"
}
rows, err := db.DB().Query(s, args...)
if err != nil {

View File

@ -7,11 +7,7 @@ import (
"github.com/go-xorm/core"
)
func TestPostgresDialect(t *testing.T) {
TestParse(t)
}
func TestParse(t *testing.T) {
func TestParsePostgres(t *testing.T) {
tests := []struct {
in string
expected string
@ -20,10 +16,10 @@ func TestParse(t *testing.T) {
{"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true},
{"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true},
{"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false},
{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
//{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
//{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
{"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true},
{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
//{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
{"dbname=db sslmode=disable", "db", true},
{"user=auser password=password dbname=db sslmode=disable", "db", true},
{"", "db", false},

View File

@ -1453,6 +1453,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
return session.Find(beans, condiBeans...)
}
// FindAndCount find the results and also return the counts
func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
session := engine.NewSession()
defer session.Close()
return session.FindAndCount(rowsSlicePtr, condiBean...)
}
// Iterate record by record handle records from table, bean's non-empty fields
// are conditions.
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error {

View File

@ -30,6 +30,7 @@ type Interface interface {
Exec(string, ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error
FindAndCount(interface{}, ...interface{}) (int64, error)
Get(interface{}) (bool, error)
GroupBy(keys string) *Session
ID(interface{}) *Session

View File

@ -122,7 +122,10 @@ func (m *Migrate) RollbackLast() error {
func (m *Migrate) getLastRunnedMigration() (*Migration, error) {
for i := len(m.migrations) - 1; i >= 0; i-- {
migration := m.migrations[i]
if m.migrationDidRun(migration) {
run, err := m.migrationDidRun(migration)
if err != nil {
return nil, err
} else if run {
return migration, nil
}
}
@ -165,7 +168,12 @@ func (m *Migrate) runMigration(migration *Migration) error {
return ErrMissingID
}
if !m.migrationDidRun(migration) {
run, err :=m.migrationDidRun(migration)
if err != nil {
return err
}
if !run {
if err := migration.Migrate(m.db); err != nil {
return err
}
@ -193,11 +201,9 @@ func (m *Migrate) createMigrationTableIfNotExists() error {
return nil
}
func (m *Migrate) migrationDidRun(mig *Migration) bool {
row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID)
var count int
row.Scan(&count)
return count > 0
func (m *Migrate) migrationDidRun(mig *Migration) (bool, error) {
count, err := m.db.SQL(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID).Count()
return count > 0, err
}
func (m *Migrate) isFirstRun() bool {

View File

@ -10,6 +10,7 @@ import (
"reflect"
"github.com/go-xorm/builder"
"github.com/go-xorm/core"
)
// Exist returns true if the record exist otherwise return false
@ -35,10 +36,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) {
return false, err
}
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s WHERE %s", tableName, condSQL)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL)
}
args = condArgs
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
if session.engine.dialect.DBType() == core.MSSQL {
sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s", tableName)
} else {
sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName)
}
args = []interface{}{}
}
} else {

View File

@ -29,6 +29,36 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return session.find(rowsSlicePtr, condiBean...)
}
// FindAndCount find the results and also return the counts
func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
session.autoResetStatement = false
err := session.find(rowsSlicePtr, condiBean...)
if err != nil {
return 0, err
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
return 0, errors.New("needs a pointer to a slice or a map")
}
sliceElementType := sliceValue.Type().Elem()
if sliceElementType.Kind() == reflect.Ptr {
sliceElementType = sliceElementType.Elem()
}
session.autoResetStatement = true
if session.statement.selectStr != "" {
session.statement.selectStr = ""
}
return session.Count(reflect.New(sliceElementType).Interface())
}
func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map {
@ -128,7 +158,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
}
args = append(session.statement.joinArgs, condArgs...)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL)
sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true)
if err != nil {
return err
}

View File

@ -488,3 +488,84 @@ func TestFindBit(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 2, len(results))
}
func TestFindMark(t *testing.T) {
type Mark struct {
Mark1 string `xorm:"VARCHAR(1)"`
Mark2 string `xorm:"VARCHAR(1)"`
MarkA string `xorm:"VARCHAR(1)"`
}
assert.NoError(t, prepareEngine())
assertSync(t, new(Mark))
cnt, err := testEngine.Insert([]Mark{
{
Mark1: "1",
Mark2: "2",
MarkA: "A",
},
{
Mark1: "1",
Mark2: "2",
MarkA: "A",
},
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
var results = make([]Mark, 0, 2)
err = testEngine.Find(&results)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(results))
}
func TestFindAndCountOneFunc(t *testing.T) {
type FindAndCountStruct struct {
Id int64
Content string
Msg bool `xorm:"bit"`
}
assert.NoError(t, prepareEngine())
assertSync(t, new(FindAndCountStruct))
cnt, err := testEngine.Insert([]FindAndCountStruct{
{
Content: "111",
Msg: false,
},
{
Content: "222",
Msg: true,
},
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
var results = make([]FindAndCountStruct, 0, 2)
cnt, err = testEngine.FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(results))
assert.EqualValues(t, 2, cnt)
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
results = make([]FindAndCountStruct, 0, 1)
cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg").
Limit(1).FindAndCount(&results)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
assert.EqualValues(t, 1, cnt)
}

View File

@ -5,6 +5,7 @@
package xorm
import (
"database/sql"
"errors"
"reflect"
"strconv"
@ -79,6 +80,13 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea
return false, nil
}
switch bean.(type) {
case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString:
return true, rows.Scan(&bean)
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString:
return true, rows.Scan(bean)
}
switch beanKind {
case reflect.Struct:
fields, err := rows.Columns()

View File

@ -5,6 +5,7 @@
package xorm
import (
"database/sql"
"fmt"
"testing"
"time"
@ -31,7 +32,7 @@ func TestGetVar(t *testing.T) {
Age: 28,
Money: 1.5,
}
_, err := testEngine.InsertOne(data)
_, err := testEngine.InsertOne(&data)
assert.NoError(t, err)
var msg string
@ -55,6 +56,27 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, true, has)
assert.EqualValues(t, 28, age2)
var id sql.NullInt64
has, err = testEngine.Table("get_var").Cols("id").Get(&id)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, true, id.Valid)
assert.EqualValues(t, data.Id, id.Int64)
var msgNull sql.NullString
has, err = testEngine.Table("get_var").Cols("msg").Get(&msgNull)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, true, msgNull.Valid)
assert.EqualValues(t, data.Msg, msgNull.String)
var nullMoney sql.NullFloat64
has, err = testEngine.Table("get_var").Cols("money").Get(&nullMoney)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, true, nullMoney.Valid)
assert.EqualValues(t, data.Money, nullMoney.Float64)
var money float64
has, err = testEngine.Table("get_var").Cols("money").Get(&money)
assert.NoError(t, err)

View File

@ -662,3 +662,81 @@ func TestInsertCreatedInt64(t *testing.T) {
assert.EqualValues(t, data.Created, data2.Created)
}
type MyUserinfo Userinfo
func (MyUserinfo) TableName() string {
return "user_info"
}
func TestInsertMulti3(t *testing.T) {
assert.NoError(t, prepareEngine())
testEngine.ShowSQL(true)
assertSync(t, new(MyUserinfo))
users := []MyUserinfo{
{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()},
{Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
cnt, err := testEngine.Insert(&users)
assert.NoError(t, err)
assert.EqualValues(t, len(users), cnt)
users2 := []*MyUserinfo{
&MyUserinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
&MyUserinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()},
&MyUserinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
&MyUserinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
cnt, err = testEngine.Insert(&users2)
assert.NoError(t, err)
assert.EqualValues(t, len(users), cnt)
}
type MyUserinfo2 struct {
Uid int64 `xorm:"id pk not null autoincr"`
Username string `xorm:"unique"`
Departname string
Alias string `xorm:"-"`
Created time.Time
Detail Userdetail `xorm:"detail_id int(11)"`
Height float64
Avatar []byte
IsMan bool
}
func (MyUserinfo2) TableName() string {
return "user_info"
}
func TestInsertMulti4(t *testing.T) {
assert.NoError(t, prepareEngine())
testEngine.ShowSQL(true)
assertSync(t, new(MyUserinfo2))
users := []MyUserinfo2{
{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()},
{Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
{Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
cnt, err := testEngine.Insert(&users)
assert.NoError(t, err)
assert.EqualValues(t, len(users), cnt)
users2 := []*MyUserinfo2{
&MyUserinfo2{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
&MyUserinfo2{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()},
&MyUserinfo2{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
&MyUserinfo2{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
cnt, err = testEngine.Insert(&users2)
assert.NoError(t, err)
assert.EqualValues(t, len(users), cnt)
}

View File

@ -17,7 +17,17 @@ import (
func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) {
if len(sqlorArgs) > 0 {
return sqlorArgs[0].(string), sqlorArgs[1:], nil
switch sqlorArgs[0].(type) {
case string:
return sqlorArgs[0].(string), sqlorArgs[1:], nil
case *builder.Builder:
return sqlorArgs[0].(*builder.Builder).ToSQL()
case builder.Builder:
bd := sqlorArgs[0].(builder.Builder)
return bd.ToSQL()
default:
return "", nil, ErrUnSupportedType
}
}
if session.statement.RawSQL != "" {
@ -60,7 +70,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
}
args := append(session.statement.joinArgs, condArgs...)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL)
sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true)
if err != nil {
return "", nil, err
}

View File

@ -10,6 +10,8 @@ import (
"testing"
"time"
"github.com/go-xorm/builder"
"github.com/stretchr/testify/assert"
)
@ -183,3 +185,48 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, err)
assertResult(t, results)
}
func TestQueryWithBuilder(t *testing.T) {
assert.NoError(t, prepareEngine())
type QueryWithBuilder struct {
Id int64 `xorm:"autoincr pk"`
Msg string `xorm:"varchar(255)"`
Age int
Money float32
Created time.Time `xorm:"created"`
}
testEngine.ShowSQL(true)
assert.NoError(t, testEngine.Sync2(new(QueryWithBuilder)))
var q = QueryWithBuilder{
Msg: "message",
Age: 20,
Money: 3000,
}
cnt, err := testEngine.Insert(&q)
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
assertResult := func(t *testing.T, results []map[string][]byte) {
assert.EqualValues(t, 1, len(results))
id, err := strconv.ParseInt(string(results[0]["id"]), 10, 64)
assert.NoError(t, err)
assert.EqualValues(t, 1, id)
assert.Equal(t, "message", string(results[0]["msg"]))
age, err := strconv.Atoi(string(results[0]["age"]))
assert.NoError(t, err)
assert.EqualValues(t, 20, age)
money, err := strconv.ParseFloat(string(results[0]["money"]), 32)
assert.NoError(t, err)
assert.EqualValues(t, 3000, money)
}
results, err := testEngine.Query(builder.Select("*").From("query_with_builder"))
assert.NoError(t, err)
assertResult(t, results)
}

View File

@ -255,6 +255,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
var structTables []*core.Table
for _, bean := range beans {

View File

@ -217,3 +217,51 @@ func TestCharst(t *testing.T) {
panic(err)
}
}
func TestSync2_1(t *testing.T) {
type WxTest struct {
Id int `xorm:"not null pk autoincr INT(64)`
Passport_user_type int16 `xorm:"null int"`
Id_delete int8 `xorm:"null int default 1"`
}
assert.NoError(t, prepareEngine())
assert.NoError(t, testEngine.DropTables("wx_test"))
assert.NoError(t, testEngine.Sync2(new(WxTest)))
assert.NoError(t, testEngine.Sync2(new(WxTest)))
}
func TestUnique_1(t *testing.T) {
type UserUnique struct {
Id int64
UserName string `xorm:"unique varchar(25) not null"`
Password string `xorm:"varchar(255) not null"`
Admin bool `xorm:"not null"`
CreatedAt time.Time `xorm:"created"`
UpdatedAt time.Time `xorm:"updated"`
}
assert.NoError(t, prepareEngine())
assert.NoError(t, testEngine.DropTables("user_unique"))
assert.NoError(t, testEngine.Sync2(new(UserUnique)))
assert.NoError(t, testEngine.DropTables("user_unique"))
assert.NoError(t, testEngine.CreateTables(new(UserUnique)))
assert.NoError(t, testEngine.CreateUniques(new(UserUnique)))
}
func TestSync2_2(t *testing.T) {
type TestSync2Index struct {
Id int64
UserId int64 `xorm:"index"`
}
assert.NoError(t, prepareEngine())
for i := 0; i < 10; i++ {
tableName := fmt.Sprintf("test_sync2_index_%d", i)
assert.NoError(t, testEngine.Table(tableName).Sync2(new(TestSync2Index)))
}
}

View File

@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
return "", nil, err
}
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true)
if err != nil {
return "", nil, err
}
@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)"
}
}
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false)
if err != nil {
return "", nil, err
}
@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err
}
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true)
if err != nil {
return "", nil, err
}
@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil
}
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) {
var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT "
@ -1149,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
}
}
}
if statement.IsForUpdate {