fix postgres with schema

This commit is contained in:
Lunny Xiao 2018-04-09 14:40:30 +08:00
parent 707e65ee77
commit faa4602c0c
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
14 changed files with 275 additions and 492 deletions

View File

@ -895,6 +895,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
}
args := []interface{}{db.Schema, tableName}
return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args
}
@ -960,7 +961,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att
var f string
if len(db.Schema) != 0 {
args = append(args, db.Schema)
f = "AND s.table_schema = $2"
f = " AND s.table_schema = $2"
}
s = fmt.Sprintf(s, f)
@ -1085,11 +1086,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1")
db.LogSQL(s, args)
if len(db.Schema) != 0 {
args = append(args, db.Schema)
s = s + " AND schemaname=$2"
}
db.LogSQL(s, args)
rows, err := db.DB().Query(s, args...)
if err != nil {

View File

@ -536,46 +536,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
return nil
}
func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) {
v := rValue(beanOrTableName)
if v.Type().Kind() == reflect.String {
return beanOrTableName.(string), nil
} else if v.Type().Kind() == reflect.Struct {
return engine.tbName(v), nil
}
return "", errors.New("bean should be a struct or struct's point")
}
func (engine *Engine) tbSchemaName(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
func (engine *Engine) tbName(v reflect.Value) string {
if tb, ok := v.Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
if v.Type().Kind() == reflect.Ptr {
if tb, ok := reflect.Indirect(v).Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
} else if v.CanAddr() {
if tb, ok := v.Addr().Interface().(TableName); ok {
return engine.tbSchemaName(tb.TableName())
}
}
return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name()))
}
// Cascade use cascade or not
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
@ -895,20 +855,8 @@ var (
func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) {
t := v.Type()
table := engine.newTable()
if tb, ok := v.Interface().(TableName); ok {
table.Name = tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
table.Name = tb.TableName()
}
}
if table.Name == "" {
table.Name = engine.TableMapper.Obj2Table(t.Name())
}
}
table.Type = t
table.Name = engine.tbNameForMap(v)
var idFieldColName string
var hasCacheTag, hasNoCacheTag bool
@ -1237,13 +1185,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
for _, bean := range beans {
v := rValue(bean)
tableName := engine.tbName(v)
tableNameNoSchema := engine.tbNameNoSchemaString(v.Interface())
table, err := engine.autoMapType(v)
if err != nil {
return err
}
isExist, err := session.Table(bean).isTableExist(tableName)
isExist, err := session.Table(bean).isTableExist(tableNameNoSchema)
if err != nil {
return err
}
@ -1269,7 +1217,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
} else {
for _, col := range table.Columns() {
isExist, err := engine.dialect.IsColumnExist(tableName, col.Name)
isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name)
if err != nil {
return err
}
@ -1289,7 +1237,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
if index.Type == core.UniqueType {
isExist, err := session.isIndexExist2(tableName, index.Cols, true)
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true)
if err != nil {
return err
}
@ -1298,13 +1246,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
err = session.addUnique(tableName, name)
err = session.addUnique(tableNameNoSchema, name)
if err != nil {
return err
}
}
} else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(tableName, index.Cols, false)
isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false)
if err != nil {
return err
}
@ -1313,7 +1261,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
err = session.addIndex(tableName, name)
err = session.addIndex(tableNameNoSchema, name)
if err != nil {
return err
}

103
engine_table.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"bytes"
"fmt"
"io"
"reflect"
"strings"
"github.com/go-xorm/core"
)
// TableNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) TableNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
engine.dialect.URI().Schema != "" &&
engine.dialect.URI().Schema != postgresPublicSchema &&
strings.Index(v, ".") == -1 {
return engine.dialect.URI().Schema + "." + v
}
return v
}
func (engine *Engine) tbName(v reflect.Value) string {
return engine.TableNameWithSchema(engine.tbNameNoSchemaString(v.Interface()))
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {
t := v.Type()
if tb, ok := v.Interface().(TableName); ok {
return tb.TableName()
} else {
if v.CanAddr() {
if tb, ok = v.Addr().Interface().(TableName); ok {
return tb.TableName()
}
}
}
return engine.TableMapper.Obj2Table(t.Name())
}
func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) {
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(w, engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name()))
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f)))
}
}
}
if l > 1 {
fmt.Fprintf(w, "%v AS %v", engine.Quote(table),
engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(w, engine.Quote(table))
}
case TableName:
fmt.Fprintf(w, tablename.(TableName).TableName())
case string:
fmt.Fprintf(w, tablename.(string))
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name()))
} else {
fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename)))
}
}
}
func (engine *Engine) tbNameNoSchemaString(tablename interface{}) string {
var buf bytes.Buffer
engine.tbNameNoSchema(&buf, tablename)
return buf.String()
}

View File

@ -95,6 +95,7 @@ type EngineInterface interface {
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
TableNameWithSchema(string) string
UnMapType(reflect.Type)
}

View File

@ -54,7 +54,11 @@ func TestRows(t *testing.T) {
}
assert.EqualValues(t, 1, cnt)
rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows))
var tbName = testEngine.Quote("user_rows")
if testEngine.Dialect().URI().Schema != "" {
tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName
}
rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
assert.NoError(t, err)
defer rows2.Close()

View File

@ -122,18 +122,11 @@ func TestIn(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, 3, cnt)
department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
var usrs []Userinfo
err = testEngine.Limit(3).Find(&usrs)
if err != nil {
t.Error(err)
panic(err)
}
if len(usrs) != 3 {
err = errors.New("there are not 3 records")
t.Error(err)
panic(err)
}
err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs)
assert.Error(t, err)
assert.EqualValues(t, 3, len(usrs))
var ids []int64
var idsStr string
@ -145,35 +138,20 @@ func TestIn(t *testing.T) {
users := make([]Userinfo, 0)
err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
assert.EqualValues(t, 3, len(users))
users = make([]Userinfo, 0)
err = testEngine.In("(id)", ids).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
assert.EqualValues(t, 3, len(users))
for _, user := range users {
if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
assert.NoError(t, err)
}
}
@ -183,87 +161,41 @@ func TestIn(t *testing.T) {
idsInterface = append(idsInterface, id)
}
department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`"
err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
if len(users) != 3 {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
}
assert.EqualValues(t, 3, len(users))
for _, user := range users {
if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] {
err = errors.New("in uses should be " + idsStr + " total 3")
t.Error(err)
panic(err)
assert.NoError(t, err)
}
}
dev := testEngine.GetColumnMapper().Obj2Table("Dev")
err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
user := new(Userinfo)
has, err := testEngine.ID(ids[0]).Get(user)
if err != nil {
t.Error(err)
panic(err)
}
if !has {
err = errors.New("get record not 1")
t.Error(err)
panic(err)
}
if user.Departname != "dev-" {
err = errors.New("update not success")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, "dev-", user.Departname)
cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update records not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{})
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("deleted records not 1")
t.Error(err)
panic(err)
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}
func TestFindAndCount(t *testing.T) {

View File

@ -96,21 +96,19 @@ func TestFind(t *testing.T) {
users := make([]Userinfo, 0)
err := testEngine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
for _, user := range users {
fmt.Println(user)
}
users2 := make([]Userinfo, 0)
userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo")
err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2)
if err != nil {
t.Error(err)
panic(err)
var tbName = testEngine.Quote(userinfo)
if testEngine.Dialect().URI().Schema != "" {
tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName
}
err = testEngine.SQL("select * from " + tbName).Find(&users2)
assert.NoError(t, err)
}
func TestFind2(t *testing.T) {
@ -238,14 +236,8 @@ func TestDistinct(t *testing.T) {
users := make([]Userinfo, 0)
departname := testEngine.GetTableMapper().Obj2Table("Departname")
err = testEngine.Distinct(departname).Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
if len(users) != 1 {
t.Error(err)
panic(errors.New("should be one record"))
}
assert.NoError(t, err)
assert.EqualValues(t, 2, len(users))
fmt.Println(users)
@ -255,11 +247,9 @@ func TestDistinct(t *testing.T) {
users2 := make([]Depart, 0)
err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
if len(users2) != 1 {
assert.NoError(t, err)
if len(users2) != 2 {
fmt.Println(len(users2))
t.Error(err)
panic(errors.New("should be one record"))
}
@ -272,18 +262,12 @@ func TestOrder(t *testing.T) {
users := make([]Userinfo, 0)
err := testEngine.OrderBy("id desc").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
users2 := make([]Userinfo, 0)
err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users2)
}
@ -293,10 +277,7 @@ func TestHaving(t *testing.T) {
users := make([]Userinfo, 0)
err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
/*users = make([]Userinfo, 0)
@ -324,18 +305,12 @@ func TestOrderSameMapper(t *testing.T) {
users := make([]Userinfo, 0)
err := testEngine.OrderBy("(id) desc").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users)
users2 := make([]Userinfo, 0)
err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
fmt.Println(users2)
}

View File

@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) {
}
assert.NoError(t, prepareEngine())
assertSync(t, new(TaskSolution))
assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
tables, err := testEngine.DBMetas()
tables1, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
pkCols := tables[0].PKColumns()
assertSync(t, new(TaskSolution))
assert.NoError(t, testEngine.Sync2(new(TaskSolution)))
tables2, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1+len(tables1), len(tables2))
var table *core.Table
for _, t := range tables2 {
if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") {
table = t
break
}
}
assert.NotEqual(t, nil, table)
pkCols := table.PKColumns()
assert.EqualValues(t, 2, len(pkCols))
assert.EqualValues(t, "uid", pkCols[0].Name)
assert.EqualValues(t, "tid", pkCols[1].Name)

View File

@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) {
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
records, err := testEngine.QueryString("select * from get_var2")
records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var2"))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) {
_, err := testEngine.Insert(data)
assert.NoError(t, err)
records, err := testEngine.QueryString("select * from get_var3")
records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var3"))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 2, len(records[0]))
@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) {
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
records, err := testEngine.QueryInterface("select * from get_var_interface")
records, err := testEngine.QueryInterface("select * from " + testEngine.TableNameWithSchema("get_var_interface"))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, err)
assertResult(t, results)
results, err = testEngine.SQL("select * from query_no_params").Query()
results, err = testEngine.SQL("select * from " + testEngine.TableNameWithSchema("query_no_params")).Query()
assert.NoError(t, err)
assertResult(t, results)
}
@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) {
assert.EqualValues(t, 3000, money)
}
results, err := testEngine.Query(builder.Select("*").From("query_with_builder"))
results, err := testEngine.Query(builder.Select("*").From(testEngine.TableNameWithSchema("query_with_builder")))
assert.NoError(t, err)
assertResult(t, results)
}

View File

@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) {
assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))
res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user")
res, err := testEngine.Exec("INSERT INTO "+testEngine.TableNameWithSchema("`userinfo_query`")+" (uid, name) VALUES (?, ?)", 1, "user")
assert.NoError(t, err)
cnt, err := res.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
results, err := testEngine.Query("select * from userinfo_query")
results, err := testEngine.Query("select * from " + testEngine.TableNameWithSchema("userinfo_query"))
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
id, err := strconv.Atoi(string(results[0]["uid"]))

View File

@ -128,14 +128,12 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return err
}
tableName := session.engine.tbNameNoSchemaString(beanOrTableName)
var needDrop = true
if !session.engine.dialect.SupportDropIfExists() {
fmt.Println("TableCheckSql:", tableName)
sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
fmt.Println("sqlStr:", sqlStr)
results, err := session.queryBytes(sqlStr, args...)
if err != nil {
return err
@ -144,8 +142,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
}
if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(tableName)
_, err = session.exec(sqlStr)
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableNameWithSchema(tableName))
_, err := session.exec(sqlStr)
return err
}
return nil
@ -157,10 +155,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error)
defer session.Close()
}
tableName, err := session.engine.tableName(beanOrTableName)
if err != nil {
return false, err
}
tableName := session.engine.tbNameNoSchemaString(beanOrTableName)
return session.isTableExist(tableName)
}
@ -190,7 +185,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.TableNameWithSchema(session.engine.Quote(tableName)))
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
@ -270,7 +265,8 @@ func (session *Session) Sync2(beans ...interface{}) error {
return err
}
structTables = append(structTables, table)
var tbName = session.tbNameNoSchema(table)
tbName := session.tbNameNoSchema(table)
tbNameWithSchema := engine.TableNameWithSchema(tbName)
var oriTable *core.Table
for _, tb := range tables {
@ -315,32 +311,32 @@ func (session *Session) Sync2(beans ...interface{}) error {
if engine.dialect.DBType() == core.MYSQL ||
engine.dialect.DBType() == core.POSTGRES {
engine.logger.Infof("Table %s column %s change type from %s to %s\n",
tbName, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, curType, expectedType)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
} else {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
tbName, col.Name, curType, expectedType)
tbNameWithSchema, col.Name, curType, expectedType)
}
} else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
}
}
} else {
if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
tbName, col.Name, curType, expectedType)
tbNameWithSchema, col.Name, curType, expectedType)
}
}
} else if expectedType == core.Varchar {
if engine.dialect.DBType() == core.MYSQL {
if oriCol.Length < col.Length {
engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
tbName, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
tbNameWithSchema, col.Name, oriCol.Length, col.Length)
_, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
}
}
}
@ -354,7 +350,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
} else {
session.statement.RefTable = table
session.statement.tableName = tbName
session.statement.tableName = tbNameWithSchema
err = session.addColumn(col.Name)
}
if err != nil {
@ -377,7 +373,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
if oriIndex != nil {
if oriIndex.Type != index.Type {
sql := engine.dialect.DropIndexSql(tbName, oriIndex)
sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
_, err = session.exec(sql)
if err != nil {
return err
@ -393,7 +389,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name2, index2 := range oriTable.Indexes {
if _, ok := foundIndexNames[name2]; !ok {
sql := engine.dialect.DropIndexSql(tbName, index2)
sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
_, err = session.exec(sql)
if err != nil {
return err
@ -404,12 +400,12 @@ func (session *Session) Sync2(beans ...interface{}) error {
for name, index := range addedNames {
if index.Type == core.UniqueType {
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addUnique(tbName, name)
session.statement.tableName = tbNameWithSchema
err = session.addUnique(tbNameWithSchema, name)
} else if index.Type == core.IndexType {
session.statement.RefTable = table
session.statement.tableName = tbName
err = session.addIndex(tbName, name)
session.statement.tableName = tbNameWithSchema
err = session.addIndex(tbNameWithSchema, name)
}
if err != nil {
return err
@ -434,7 +430,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableNameWithSchema(table.Name), colName)
}
}
}

View File

@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) {
defer session.Close()
err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
user2 := Userinfo{Username: "yyy"}
_, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
fmt.Println(err)
//t.Error(err)
return
}
assert.NoError(t, err)
_, err = session.Delete(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
return
}
// panic(err) !nashtsai! should remove this
assert.NoError(t, err)
}
func TestCombineTransaction(t *testing.T) {
@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) {
defer session.Close()
err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
assert.NoError(t, err)
user2 := Userinfo{Username: "zzz"}
_, err = session.Where("id = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
assert.NoError(t, err)
_, err = session.Exec("delete from userinfo where username = ?", user2.Username)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
}
_, err = session.Exec("delete from "+testEngine.TableNameWithSchema("userinfo")+" where username = ?", user2.Username)
assert.NoError(t, err)
err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
}
func TestCombineTransactionSameMapper(t *testing.T) {
@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) {
counter()
defer counter()
session := testEngine.NewSession()
defer session.Close()
err := session.Begin()
if err != nil {
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
user2 := Userinfo{Username: "zzz"}
_, err = session.Where("(id) = ?", 0).Update(&user2)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
_, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username)
if err != nil {
session.Rollback()
t.Error(err)
panic(err)
return
}
_, err = session.Exec("delete from "+testEngine.TableNameWithSchema("`Userinfo`")+" where `Username` = ?", user2.Username)
assert.NoError(t, err)
err = session.Commit()
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
}

View File

@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) {
col1 := &UpdateAllCols{Ptr: &s}
err = testEngine.Sync(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
_, err = testEngine.Insert(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) {
func TestUpdateSameMapper(t *testing.T) {
assert.NoError(t, prepareEngine())
oldMapper := testEngine.GetColumnMapper()
oldMapper := testEngine.GetTableMapper()
testEngine.UnMapType(rValue(new(Userinfo)).Type())
testEngine.UnMapType(rValue(new(Condi)).Type())
testEngine.UnMapType(rValue(new(Article)).Type())
@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) {
var ori Userinfo
has, err := testEngine.Get(&ori)
if err != nil {
t.Error(err)
panic(err)
}
if !has {
t.Error(errors.New("not exist"))
panic(errors.New("not exist"))
}
assert.NoError(t, err)
assert.True(t, has)
// update by id
user := Userinfo{Username: "xxx", Height: 1.2}
cnt, err := testEngine.ID(ori.Uid).Update(&user)
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update not returned 1")
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
condi := Condi{"Username": "zzz", "Departname": ""}
cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi)
if err != nil {
t.Error(err)
panic(err)
}
if cnt != 1 {
err = errors.New("update not returned 1")
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
total, err := testEngine.Count(&user)
if err != nil {
t.Error(err)
panic(err)
}
if cnt != total {
err = errors.New("insert not returned 1")
t.Error(err)
panic(err)
return
}
assert.NoError(t, err)
assert.EqualValues(t, cnt, total)
err = testEngine.Sync(&Article{})
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
defer func() {
err = testEngine.DropTables(&Article{})
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
}()
a := &Article{0, "1", "2", "3", "4", "5", 2}
cnt, err = testEngine.Insert(a)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if cnt != 1 {
err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) {
}
cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"})
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if cnt != 1 {
err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt))
@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) {
col1 := &UpdateAllCols{}
err = testEngine.Sync(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
_, err = testEngine.Insert(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
col2 := &UpdateAllCols{col1.Id, true, "", nil}
_, err = testEngine.ID(col2.Id).AllCols().Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
col3 := &UpdateAllCols{}
has, err = testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))
@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) {
{
col1 := &UpdateMustCols{}
err = testEngine.Sync(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
_, err = testEngine.Insert(col1)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
col2 := &UpdateMustCols{col1.Id, true, ""}
boolStr := testEngine.GetColumnMapper().Obj2Table("Bool")
stringStr := testEngine.GetColumnMapper().Obj2Table("String")
_, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
col3 := &UpdateMustCols{}
has, err := testEngine.ID(col2.Id).Get(col3)
if err != nil {
t.Error(err)
panic(err)
}
assert.NoError(t, err)
if !has {
err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id))

View File

@ -225,24 +225,6 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
return nil
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.String {
statement.AltTableName = tableNameOrBean.(string)
} else if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
statement.AltTableName = statement.Engine.tbName(v)
}
return statement
}
// Auto generating update columnes and values according a struct
func buildUpdates(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool,
@ -743,6 +725,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement
}
// Table tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
v := rValue(tableNameOrBean)
t := v.Type()
if t.Kind() == reflect.Struct {
var err error
statement.RefTable, err = statement.Engine.autoMapType(v)
if err != nil {
statement.Engine.logger.Error(err)
return statement
}
}
statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(tableNameOrBean))
return statement
}
// Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
var buf bytes.Buffer
@ -752,54 +751,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
switch tablename.(type) {
case []string:
t := tablename.([]string)
if len(t) > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1]))
} else if len(t) == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(t[0]))
}
case []interface{}:
t := tablename.([]interface{})
l := len(t)
var table string
if l > 0 {
f := t[0]
switch f.(type) {
case string:
table = f.(string)
case TableName:
table = f.(TableName).TableName()
default:
v := rValue(f)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f)))
}
}
}
if l > 1 {
fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table),
statement.Engine.Quote(fmt.Sprintf("%v", t[1])))
} else if l == 1 {
fmt.Fprintf(&buf, statement.Engine.Quote(table))
}
case TableName:
fmt.Fprintf(&buf, tablename.(TableName).TableName())
case string:
fmt.Fprintf(&buf, tablename.(string))
default:
v := rValue(tablename)
t := v.Type()
if t.Kind() == reflect.Struct {
fmt.Fprintf(&buf, statement.Engine.tbName(v))
} else {
fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename)))
}
}
statement.Engine.tbNameNoSchema(&buf, tablename)
fmt.Fprintf(&buf, " ON %v", condition)
statement.JoinStr = buf.String()
@ -915,7 +867,7 @@ func (statement *Statement) genDelIndexSQL() []string {
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName()))
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
}
sqls = append(sqls, sql)
}