fix postgres with schema

This commit is contained in:
Lunny Xiao 2018-04-09 14:40:30 +08:00
parent 707e65ee77
commit faa4602c0c
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
14 changed files with 275 additions and 492 deletions

View File

@ -895,6 +895,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
args := []interface{}{db.Schema, tableName} args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
} }
@ -960,7 +961,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
var f string var f string
if len(db.Schema) != 0 { if len(db.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.Schema)
f = "AND s.table_schema = $2" f = " AND s.table_schema = $2"
} }
s = fmt.Sprintf(s, f) s = fmt.Sprintf(s, f)
@ -1085,11 +1086,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
db.LogSQL(s, args)
if len(db.Schema) != 0 { if len(db.Schema) != 0 {
args = append(args, db.Schema) args = append(args, db.Schema)
s = s + " AND schemaname=$2" s = s + " AND schemaname=$2"
} }
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...) rows, err := db.DB().Query(s, args...)
if err != nil { if err != nil {

View File

@ -536,46 +536,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return nil return nil
} }
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
v := rValue(beanOrTableName)
if v.Type().Kind() == reflect.String {
return beanOrTableName.(string), nil
} else if v.Type().Kind() == reflect.Struct {
return engine.tbName(v), nil
}
return "", errors.New("bean should be a struct or struct's point")
}
func (engine *Engine) tbSchemaName(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
func (engine *Engine) tbName(v reflect.Value) string {
if tb, ok := v.Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
if v.Type().Kind() == reflect.Ptr {
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
} else if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
}
return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()))
}
// Cascade use cascade or not // Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession() session := engine.NewSession()
@ -895,20 +855,8 @@ var (
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type() t := v.Type()
table := engine.newTable() table := engine.newTable()
if tb, ok := v.Interface().(TableName); ok {
table.Name = tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
table.Name = tb.TableName()
}
}
if table.Name == "" {
table.Name = engine.TableMapper.Obj2Table(t.Name())
}
}
table.Type = t table.Type = t
table.Name = engine.tbNameForMap(v)
var idFieldColName string var idFieldColName string
var hasCacheTag, hasNoCacheTag bool var hasCacheTag, hasNoCacheTag bool
@ -1237,13 +1185,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans { for _, bean := range beans {
v := rValue(bean) v := rValue(bean)
tableName := engine.tbName(v) tableNameNoSchema := engine.tbNameNoSchemaString(v.Interface())
table, err := engine.autoMapType(v) table, err := engine.autoMapType(v)
if err != nil { if err != nil {
return err return err
} }
isExist, err := session.Table(bean).isTableExist(tableName) isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
if err != nil { if err != nil {
return err return err
} }
@ -1269,7 +1217,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
} else { } else {
for _, col := range table.Columns() { for _, col := range table.Columns() {
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
if err != nil { if err != nil {
return err return err
} }
@ -1289,7 +1237,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
isExist, err := session.isIndexExist2(tableName, index.Cols, true) isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil { if err != nil {
return err return err
} }
@ -1298,13 +1246,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
err = session.addUnique(tableName, name) err = session.addUnique(tableNameNoSchema, name)
if err != nil { if err != nil {
return err return err
} }
} }
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(tableName, index.Cols, false) isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil { if err != nil {
return err return err
} }
@ -1313,7 +1261,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
err = session.addIndex(tableName, name) err = session.addIndex(tableNameNoSchema, name)
if err != nil { if err != nil {
return err return err
} }

103
engine_table.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
"github.com/go-xorm/core"
)
// TableNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) TableNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
func (engine *Engine) tbName(v reflect.Value) string {
return engine.TableNameWithSchema(engine.tbNameNoSchemaString(v.Interface()))
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {
t := v.Type()
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
return tb.TableName()
}
}
}
return engine.TableMapper.Obj2Table(t.Name())
}
func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(w, engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name()))
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f)))
}
}
}
if l > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(w, engine.Quote(table))
}
case TableName:
fmt.Fprintf(w, tablename.(TableName).TableName())
case string:
fmt.Fprintf(w, tablename.(string))
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name()))
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename)))
}
}
}
func (engine *Engine) tbNameNoSchemaString(tablename interface{}) string {
var buf bytes.Buffer
engine.tbNameNoSchema(&buf, tablename)
return buf.String()
}

View File

@ -95,6 +95,7 @@ type EngineInterface interface {
Sync2(...interface{}) error Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table TableInfo(bean interface{}) *Table
TableNameWithSchema(string) string
UnMapType(reflect.Type) UnMapType(reflect.Type)
} }

View File

@ -54,7 +54,11 @@ func TestRows(t *testing.T) {
} }
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows)) var tbName = testEngine.Quote("user_rows")
if testEngine.Dialect().URI().Schema != "" {
tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName
}
rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
assert.NoError(t, err) assert.NoError(t, err)
defer rows2.Close() defer rows2.Close()

View File

@ -122,18 +122,11 @@ func TestIn(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 3, cnt) assert.EqualValues(t, 3, cnt)
department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
var usrs []Userinfo var usrs []Userinfo
err = testEngine.Limit(3).Find(&usrs) err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs)
if err != nil { assert.Error(t, err)
t.Error(err) assert.EqualValues(t, 3, len(usrs))
panic(err)
}
if len(usrs) != 3 {
err = errors.New("there are not 3 records")
t.Error(err)
panic(err)
}
var ids []int64 var ids []int64
var idsStr string var idsStr string
@ -145,35 +138,20 @@ func TestIn(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
if len(users) != 3 { assert.EqualValues(t, 3, len(users))
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
users = make([]Userinfo, 0) users = make([]Userinfo, 0)
err = testEngine.In("(id)", ids).Find(&users) err = testEngine.In("(id)", ids).Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
if len(users) != 3 { assert.EqualValues(t, 3, len(users))
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
for _, user := range users { for _, user := range users {
if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
err = errors.New("in uses should be " + idsStr + " total 3") err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err) assert.NoError(t, err)
panic(err)
} }
} }
@ -183,87 +161,41 @@ func TestIn(t *testing.T) {
idsInterface = append(idsInterface, id) idsInterface = append(idsInterface, id)
} }
department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
assert.EqualValues(t, 3, len(users))
if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
for _, user := range users { for _, user := range users {
if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
err = errors.New("in uses should be " + idsStr + " total 3") err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err) assert.NoError(t, err)
panic(err)
} }
} }
dev := testEngine.GetColumnMapper().Obj2Table("Dev") dev := testEngine.GetColumnMapper().Obj2Table("Dev")
err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users)
assert.NoError(t, err)
if err != nil {
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 1, cnt)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
user := new(Userinfo) user := new(Userinfo)
has, err := testEngine.ID(ids[0]).Get(user) has, err := testEngine.ID(ids[0]).Get(user)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.True(t, has)
panic(err) assert.EqualValues(t, "dev-", user.Departname)
}
if !has {
err = errors.New("get record not 1")
t.Error(err)
panic(err)
}
if user.Departname != "dev-" {
err = errors.New("update not success")
t.Error(err)
panic(err)
}
cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 1, cnt)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{})
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 1, cnt)
panic(err)
}
if cnt != 1 {
err = errors.New("deleted records not 1")
t.Error(err)
panic(err)
}
} }
func TestFindAndCount(t *testing.T) { func TestFindAndCount(t *testing.T) {

View File

@ -96,21 +96,19 @@ func TestFind(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := testEngine.Find(&users) err := testEngine.Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
for _, user := range users { for _, user := range users {
fmt.Println(user) fmt.Println(user)
} }
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo")
err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2) var tbName = testEngine.Quote(userinfo)
if err != nil { if testEngine.Dialect().URI().Schema != "" {
t.Error(err) tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName
panic(err)
} }
err = testEngine.SQL("select * from " + tbName).Find(&users2)
assert.NoError(t, err)
} }
func TestFind2(t *testing.T) { func TestFind2(t *testing.T) {
@ -238,14 +236,8 @@ func TestDistinct(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
departname := testEngine.GetTableMapper().Obj2Table("Departname") departname := testEngine.GetTableMapper().Obj2Table("Departname")
err = testEngine.Distinct(departname).Find(&users) err = testEngine.Distinct(departname).Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 2, len(users))
panic(err)
}
if len(users) != 1 {
t.Error(err)
panic(errors.New("should be one record"))
}
fmt.Println(users) fmt.Println(users)
@ -255,11 +247,9 @@ func TestDistinct(t *testing.T) {
users2 := make([]Depart, 0) users2 := make([]Depart, 0)
err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2) err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
if err != nil { assert.NoError(t, err)
t.Error(err) if len(users2) != 2 {
panic(err) fmt.Println(len(users2))
}
if len(users2) != 1 {
t.Error(err) t.Error(err)
panic(errors.New("should be one record")) panic(errors.New("should be one record"))
} }
@ -272,18 +262,12 @@ func TestOrder(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := testEngine.OrderBy("id desc").Find(&users) err := testEngine.OrderBy("id desc").Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
err = testEngine.Asc("id", "username").Desc("height").Find(&users2) err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users2) fmt.Println(users2)
} }
@ -293,10 +277,7 @@ func TestHaving(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
/*users = make([]Userinfo, 0) /*users = make([]Userinfo, 0)
@ -324,18 +305,12 @@ func TestOrderSameMapper(t *testing.T) {
users := make([]Userinfo, 0) users := make([]Userinfo, 0)
err := testEngine.OrderBy("(id) desc").Find(&users) err := testEngine.OrderBy("(id) desc").Find(&users)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users) fmt.Println(users)
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
fmt.Println(users2) fmt.Println(users2)
} }

View File

@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) {
} }
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
assertSync(t, new(TaskSolution))
assert.NoError(t, testEngine.Sync2(new(TaskSolution))) tables1, err := testEngine.DBMetas()
tables, err := testEngine.DBMetas()
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
pkCols := tables[0].PKColumns() assertSync(t, new(TaskSolution))
assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
tables2, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1+len(tables1), len(tables2))
var table *core.Table
for _, t := range tables2 {
if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
table = t
break
}
}
assert.NotEqual(t, nil, table)
pkCols := table.PKColumns()
assert.EqualValues(t, 2, len(pkCols)) assert.EqualValues(t, 2, len(pkCols))
assert.EqualValues(t, "uid", pkCols[0].Name) assert.EqualValues(t, "uid", pkCols[0].Name)
assert.EqualValues(t, "tid", pkCols[1].Name) assert.EqualValues(t, "tid", pkCols[1].Name)

View File

@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) {
_, err := testEngine.InsertOne(data) _, err := testEngine.InsertOne(data)
assert.NoError(t, err) assert.NoError(t, err)
records, err := testEngine.QueryString("select * from get_var2") records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var2"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(records)) assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0])) assert.Equal(t, 5, len(records[0]))
@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) {
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
assert.NoError(t, err) assert.NoError(t, err)
records, err := testEngine.QueryString("select * from get_var3") records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var3"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(records)) assert.Equal(t, 1, len(records))
assert.Equal(t, 2, len(records[0])) assert.Equal(t, 2, len(records[0]))
@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) {
_, err := testEngine.InsertOne(data) _, err := testEngine.InsertOne(data)
assert.NoError(t, err) assert.NoError(t, err)
records, err := testEngine.QueryInterface("select * from get_var_interface") records, err := testEngine.QueryInterface("select * from " + testEngine.TableNameWithSchema("get_var_interface"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(records)) assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0])) assert.Equal(t, 5, len(records[0]))
@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assertResult(t, results) assertResult(t, results)
results, err = testEngine.SQL("select * from query_no_params").Query() results, err = testEngine.SQL("select * from " + testEngine.TableNameWithSchema("query_no_params")).Query()
assert.NoError(t, err) assert.NoError(t, err)
assertResult(t, results) assertResult(t, results)
} }
@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) {
assert.EqualValues(t, 3000, money) assert.EqualValues(t, 3000, money)
} }
results, err := testEngine.Query(builder.Select("*").From("query_with_builder")) results, err := testEngine.Query(builder.Select("*").From(testEngine.TableNameWithSchema("query_with_builder")))
assert.NoError(t, err) assert.NoError(t, err)
assertResult(t, results) assertResult(t, results)
} }

View File

@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) {
assert.NoError(t, testEngine.Sync2(new(UserinfoQuery))) assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))
res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user") res, err := testEngine.Exec("INSERT INTO "+testEngine.TableNameWithSchema("`userinfo_query`")+" (uid, name) VALUES (?, ?)", 1, "user")
assert.NoError(t, err) assert.NoError(t, err)
cnt, err := res.RowsAffected() cnt, err := res.RowsAffected()
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
results, err := testEngine.Query("select * from userinfo_query") results, err := testEngine.Query("select * from " + testEngine.TableNameWithSchema("userinfo_query"))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, len(results))
id, err := strconv.Atoi(string(results[0]["uid"])) id, err := strconv.Atoi(string(results[0]["uid"]))

View File

@ -128,14 +128,12 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
} }
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName) tableName := session.engine.tbNameNoSchemaString(beanOrTableName)
if err != nil {
return err
}
var needDrop = true var needDrop = true
if !session.engine.dialect.SupportDropIfExists() { if !session.engine.dialect.SupportDropIfExists() {
fmt.Println("TableCheckSql:", tableName)
sqlStr, args := session.engine.dialect.TableCheckSql(tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
fmt.Println("sqlStr:", sqlStr)
results, err := session.queryBytes(sqlStr, args...) results, err := session.queryBytes(sqlStr, args...)
if err != nil { if err != nil {
return err return err
@ -144,8 +142,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
} }
if needDrop { if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(tableName) sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableNameWithSchema(tableName))
_, err = session.exec(sqlStr) _, err := session.exec(sqlStr)
return err return err
} }
return nil return nil
@ -157,10 +155,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close() defer session.Close()
} }
tableName, err := session.engine.tableName(beanOrTableName) tableName := session.engine.tbNameNoSchemaString(beanOrTableName)
if err != nil {
return false, err
}
return session.isTableExist(tableName) return session.isTableExist(tableName)
} }
@ -190,7 +185,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
func (session *Session) isTableEmpty(tableName string) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64 var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.TableNameWithSchema(session.engine.Quote(tableName)))
err := session.queryRow(sqlStr).Scan(&total) err := session.queryRow(sqlStr).Scan(&total)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -270,7 +265,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err return err
} }
structTables = append(structTables, table) structTables = append(structTables, table)
var tbName = session.tbNameNoSchema(table) tbName := session.tbNameNoSchema(table)
tbNameWithSchema := engine.TableNameWithSchema(tbName)
var oriTable *core.Table var oriTable *core.Table
for _, tb := range tables { for _, tb := range tables {
@ -315,32 +311,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
if engine.dialect.DBType() == core.MYSQL || if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES { engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n", engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} else { } else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} }
} }
} else { } else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbName, col.Name, curType, expectedType) tbNameWithSchema, col.Name, curType, expectedType)
} }
} }
} else if expectedType == core.Varchar { } else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL { if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length { if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length) tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} }
} }
} }
@ -354,7 +350,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
} }
} else { } else {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbName session.statement.tableName = tbNameWithSchema
err = session.addColumn(col.Name) err = session.addColumn(col.Name)
} }
if err != nil { if err != nil {
@ -377,7 +373,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil { if oriIndex != nil {
if oriIndex.Type != index.Type { if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex) sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -393,7 +389,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes { for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok { if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2) sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
_, err = session.exec(sql) _, err = session.exec(sql)
if err != nil { if err != nil {
return err return err
@ -404,12 +400,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames { for name, index := range addedNames {
if index.Type == core.UniqueType { if index.Type == core.UniqueType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbName session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbName, name) err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType { } else if index.Type == core.IndexType {
session.statement.RefTable = table session.statement.RefTable = table
session.statement.tableName = tbName session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbName, name) err = session.addIndex(tbNameWithSchema, name)
} }
if err != nil { if err != nil {
return err return err
@ -434,7 +430,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for _, colName := range table.ColumnsSeq() { for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil { if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableNameWithSchema(table.Name), colName)
} }
} }
} }

View File

@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) {
defer session.Close() defer session.Close()
err := session.Begin() err := session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
return
}
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1) _, err = session.Insert(&user1)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
return
}
user2 := Userinfo{Username: "yyy"} user2 := Userinfo{Username: "yyy"}
_, err = session.Where("(id) = ?", 0).Update(&user2) _, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil { assert.NoError(t, err)
session.Rollback()
fmt.Println(err)
//t.Error(err)
return
}
_, err = session.Delete(&user2) _, err = session.Delete(&user2)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
return
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
return
}
// panic(err) !nashtsai! should remove this
} }
func TestCombineTransaction(t *testing.T) { func TestCombineTransaction(t *testing.T) {
@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) {
defer session.Close() defer session.Close()
err := session.Begin() err := session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1) _, err = session.Insert(&user1)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
}
user2 := Userinfo{Username: "zzz"} user2 := Userinfo{Username: "zzz"}
_, err = session.Where("id = ?", 0).Update(&user2) _, err = session.Where("id = ?", 0).Update(&user2)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
}
_, err = session.Exec("delete from userinfo where username = ?", user2.Username) _, err = session.Exec("delete from "+testEngine.TableNameWithSchema("userinfo")+" where username = ?", user2.Username)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
} }
func TestCombineTransactionSameMapper(t *testing.T) { func TestCombineTransactionSameMapper(t *testing.T) {
@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) {
counter() counter()
defer counter() defer counter()
session := testEngine.NewSession() session := testEngine.NewSession()
defer session.Close() defer session.Close()
err := session.Begin() err := session.Begin()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
return
}
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1) _, err = session.Insert(&user1)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
return
}
user2 := Userinfo{Username: "zzz"} user2 := Userinfo{Username: "zzz"}
_, err = session.Where("(id) = ?", 0).Update(&user2) _, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
return
}
_, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) _, err = session.Exec("delete from "+testEngine.TableNameWithSchema("`Userinfo`")+" where `Username` = ?", user2.Username)
if err != nil { assert.NoError(t, err)
session.Rollback()
t.Error(err)
panic(err)
return
}
err = session.Commit() err = session.Commit()
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
} }

View File

@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) {
col1 := &UpdateAllCols{Ptr: &s} col1 := &UpdateAllCols{Ptr: &s}
err = testEngine.Sync(col1) err = testEngine.Sync(col1)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
_, err = testEngine.Insert(col1) _, err = testEngine.Insert(col1)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
col2 := &UpdateAllCols{col1.Id, true, "", nil} col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2) _, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
col3 := &UpdateAllCols{} col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3) has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if !has { if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) {
func TestUpdateSameMapper(t *testing.T) { func TestUpdateSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
oldMapper := testEngine.GetColumnMapper() oldMapper := testEngine.GetTableMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type()) testEngine.UnMapType(rValue(new(Condi)).Type())
testEngine.UnMapType(rValue(new(Article)).Type()) testEngine.UnMapType(rValue(new(Article)).Type())
@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) {
var ori Userinfo var ori Userinfo
has, err := testEngine.Get(&ori) has, err := testEngine.Get(&ori)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.True(t, has)
panic(err)
}
if !has {
t.Error(errors.New("not exist"))
panic(errors.New("not exist"))
}
// update by id // update by id
user := Userinfo{Username: "xxx", Height: 1.2} user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.ID(ori.Uid).Update(&user) cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 1, cnt)
panic(err)
}
if cnt != 1 {
err = errors.New("update not returned 1")
t.Error(err)
panic(err)
return
}
condi := Condi{"Username": "zzz", "Departname": ""} condi := Condi{"Username": "zzz", "Departname": ""}
cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, 1, cnt)
panic(err)
}
if cnt != 1 {
err = errors.New("update not returned 1")
t.Error(err)
panic(err)
return
}
cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
total, err := testEngine.Count(&user) total, err := testEngine.Count(&user)
if err != nil { assert.NoError(t, err)
t.Error(err) assert.EqualValues(t, cnt, total)
panic(err)
}
if cnt != total {
err = errors.New("insert not returned 1")
t.Error(err)
panic(err)
return
}
err = testEngine.Sync(&Article{}) err = testEngine.Sync(&Article{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
defer func() { defer func() {
err = testEngine.DropTables(&Article{}) err = testEngine.DropTables(&Article{})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
}() }()
a := &Article{0, "1", "2", "3", "4", "5", 2} a := &Article{0, "1", "2", "3", "4", "5", 2}
cnt, err = testEngine.Insert(a) cnt, err = testEngine.Insert(a)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if cnt != 1 { if cnt != 1 {
err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) {
} }
cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if cnt != 1 { if cnt != 1 {
err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) {
col1 := &UpdateAllCols{} col1 := &UpdateAllCols{}
err = testEngine.Sync(col1) err = testEngine.Sync(col1)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
_, err = testEngine.Insert(col1) _, err = testEngine.Insert(col1)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
col2 := &UpdateAllCols{col1.Id, true, "", nil} col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2) _, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
col3 := &UpdateAllCols{} col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3) has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if !has { if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) {
{ {
col1 := &UpdateMustCols{} col1 := &UpdateMustCols{}
err = testEngine.Sync(col1) err = testEngine.Sync(col1)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
_, err = testEngine.Insert(col1) _, err = testEngine.Insert(col1)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
col2 := &UpdateMustCols{col1.Id, true, ""} col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.GetColumnMapper().Obj2Table("Bool") boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
stringStr := testEngine.GetColumnMapper().Obj2Table("String") stringStr := testEngine.GetColumnMapper().Obj2Table("String")
_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
col3 := &UpdateMustCols{} col3 := &UpdateMustCols{}
has, err := testEngine.ID(col2.Id).Get(col3) has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil { assert.NoError(t, err)
t.Error(err)
panic(err)
}
if !has { if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))

View File

@ -225,24 +225,6 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
return nil return nil
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v)
}
return statement
}
// Auto generating update columnes and values according a struct // Auto generating update columnes and values according a struct
func buildUpdates(engine *Engine, table *core.Table, bean interface{}, func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion bool, includeUpdated bool, includeNil bool,
@ -743,6 +725,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement return statement
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
}
statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(tableNameOrBean))
return statement
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf bytes.Buffer var buf bytes.Buffer
@ -752,54 +751,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP) fmt.Fprintf(&buf, "%v JOIN ", joinOP)
} }
switch tablename.(type) { statement.Engine.tbNameNoSchema(&buf, tablename)
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
}
}
}
if l > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table))
}
case TableName:
fmt.Fprintf(&buf, tablename.(TableName).TableName())
case string:
fmt.Fprintf(&buf, tablename.(string))
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
}
fmt.Fprintf(&buf, " ON %v", condition) fmt.Fprintf(&buf, " ON %v", condition)
statement.JoinStr = buf.String() statement.JoinStr = buf.String()
@ -915,7 +867,7 @@ func (statement *Statement) genDelIndexSQL() []string {
} }
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
if statement.Engine.dialect.IndexOnTable() { if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
} }
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }