diff --git a/dialect_postgres.go b/dialect_postgres.go index f2858f19..b8c8308e 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -895,6 +895,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { args := []interface{}{tableName} return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args } + args := []interface{}{db.Schema, tableName} return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } @@ -960,7 +961,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att var f string if len(db.Schema) != 0 { args = append(args, db.Schema) - f = "AND s.table_schema = $2" + f = " AND s.table_schema = $2" } s = fmt.Sprintf(s, f) @@ -1085,11 +1086,11 @@ func (db *postgres) GetTables() ([]*core.Table, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") - db.LogSQL(s, args) if len(db.Schema) != 0 { args = append(args, db.Schema) s = s + " AND schemaname=$2" } + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) if err != nil { diff --git a/engine.go b/engine.go index 95a0184a..35817f3f 100644 --- a/engine.go +++ b/engine.go @@ -536,46 +536,6 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D return nil } -func (engine *Engine) tableName(beanOrTableName interface{}) (string, error) { - v := rValue(beanOrTableName) - if v.Type().Kind() == reflect.String { - return beanOrTableName.(string), nil - } else if v.Type().Kind() == reflect.Struct { - return engine.tbName(v), nil - } - return "", errors.New("bean should be a struct or struct's point") -} - -func (engine *Engine) tbSchemaName(v string) string { - // Add schema name as prefix of table name. - // Only for postgres database. - if engine.dialect.DBType() == core.POSTGRES && - engine.dialect.URI().Schema != "" && - engine.dialect.URI().Schema != postgresPublicSchema && - strings.Index(v, ".") == -1 { - return engine.dialect.URI().Schema + "." + v - } - return v -} - -func (engine *Engine) tbName(v reflect.Value) string { - if tb, ok := v.Interface().(TableName); ok { - return engine.tbSchemaName(tb.TableName()) - - } - - if v.Type().Kind() == reflect.Ptr { - if tb, ok := reflect.Indirect(v).Interface().(TableName); ok { - return engine.tbSchemaName(tb.TableName()) - } - } else if v.CanAddr() { - if tb, ok := v.Addr().Interface().(TableName); ok { - return engine.tbSchemaName(tb.TableName()) - } - } - return engine.tbSchemaName(engine.TableMapper.Obj2Table(reflect.Indirect(v).Type().Name())) -} - // Cascade use cascade or not func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() @@ -895,20 +855,8 @@ var ( func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { t := v.Type() table := engine.newTable() - if tb, ok := v.Interface().(TableName); ok { - table.Name = tb.TableName() - } else { - if v.CanAddr() { - if tb, ok = v.Addr().Interface().(TableName); ok { - table.Name = tb.TableName() - } - } - if table.Name == "" { - table.Name = engine.TableMapper.Obj2Table(t.Name()) - } - } - table.Type = t + table.Name = engine.tbNameForMap(v) var idFieldColName string var hasCacheTag, hasNoCacheTag bool @@ -1237,13 +1185,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := rValue(bean) - tableName := engine.tbName(v) + tableNameNoSchema := engine.tbNameNoSchemaString(v.Interface()) table, err := engine.autoMapType(v) if err != nil { return err } - isExist, err := session.Table(bean).isTableExist(tableName) + isExist, err := session.Table(bean).isTableExist(tableNameNoSchema) if err != nil { return err } @@ -1269,7 +1217,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { } } else { for _, col := range table.Columns() { - isExist, err := engine.dialect.IsColumnExist(tableName, col.Name) + isExist, err := engine.dialect.IsColumnExist(tableNameNoSchema, col.Name) if err != nil { return err } @@ -1289,7 +1237,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if index.Type == core.UniqueType { - isExist, err := session.isIndexExist2(tableName, index.Cols, true) + isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, true) if err != nil { return err } @@ -1298,13 +1246,13 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } - err = session.addUnique(tableName, name) + err = session.addUnique(tableNameNoSchema, name) if err != nil { return err } } } else if index.Type == core.IndexType { - isExist, err := session.isIndexExist2(tableName, index.Cols, false) + isExist, err := session.isIndexExist2(tableNameNoSchema, index.Cols, false) if err != nil { return err } @@ -1313,7 +1261,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } - err = session.addIndex(tableName, name) + err = session.addIndex(tableNameNoSchema, name) if err != nil { return err } diff --git a/engine_table.go b/engine_table.go new file mode 100644 index 00000000..bb74ac16 --- /dev/null +++ b/engine_table.go @@ -0,0 +1,103 @@ +// Copyright 2018 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + + "github.com/go-xorm/core" +) + +// TableNameWithSchema will automatically add schema prefix on table name +func (engine *Engine) TableNameWithSchema(v string) string { + // Add schema name as prefix of table name. + // Only for postgres database. + if engine.dialect.DBType() == core.POSTGRES && + engine.dialect.URI().Schema != "" && + engine.dialect.URI().Schema != postgresPublicSchema && + strings.Index(v, ".") == -1 { + return engine.dialect.URI().Schema + "." + v + } + return v +} + +func (engine *Engine) tbName(v reflect.Value) string { + return engine.TableNameWithSchema(engine.tbNameNoSchemaString(v.Interface())) +} + +func (engine *Engine) tbNameForMap(v reflect.Value) string { + t := v.Type() + if tb, ok := v.Interface().(TableName); ok { + return tb.TableName() + } else { + if v.CanAddr() { + if tb, ok = v.Addr().Interface().(TableName); ok { + return tb.TableName() + } + } + } + return engine.TableMapper.Obj2Table(t.Name()) +} + +func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) { + switch tablename.(type) { + case []string: + t := tablename.([]string) + if len(t) > 1 { + fmt.Fprintf(w, "%v AS %v", engine.Quote(t[0]), engine.Quote(t[1])) + } else if len(t) == 1 { + fmt.Fprintf(w, engine.Quote(t[0])) + } + case []interface{}: + t := tablename.([]interface{}) + l := len(t) + var table string + if l > 0 { + f := t[0] + switch f.(type) { + case string: + table = f.(string) + case TableName: + table = f.(TableName).TableName() + default: + v := rValue(f) + t := v.Type() + if t.Kind() == reflect.Struct { + fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name())) + } else { + fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f))) + } + } + } + if l > 1 { + fmt.Fprintf(w, "%v AS %v", engine.Quote(table), + engine.Quote(fmt.Sprintf("%v", t[1]))) + } else if l == 1 { + fmt.Fprintf(w, engine.Quote(table)) + } + case TableName: + fmt.Fprintf(w, tablename.(TableName).TableName()) + case string: + fmt.Fprintf(w, tablename.(string)) + default: + v := rValue(tablename) + t := v.Type() + if t.Kind() == reflect.Struct { + fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name())) + } else { + fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename))) + } + } +} + +func (engine *Engine) tbNameNoSchemaString(tablename interface{}) string { + var buf bytes.Buffer + engine.tbNameNoSchema(&buf, tablename) + return buf.String() +} diff --git a/interface.go b/interface.go index 42a8fe25..ecc2c4f7 100644 --- a/interface.go +++ b/interface.go @@ -95,6 +95,7 @@ type EngineInterface interface { Sync2(...interface{}) error StoreEngine(storeEngine string) *Session TableInfo(bean interface{}) *Table + TableNameWithSchema(string) string UnMapType(reflect.Type) } diff --git a/rows_test.go b/rows_test.go index c48938a9..ebfabef0 100644 --- a/rows_test.go +++ b/rows_test.go @@ -54,7 +54,11 @@ func TestRows(t *testing.T) { } assert.EqualValues(t, 1, cnt) - rows2, err := testEngine.SQL("SELECT * FROM user_rows").Rows(new(UserRows)) + var tbName = testEngine.Quote("user_rows") + if testEngine.Dialect().URI().Schema != "" { + tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName + } + rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) assert.NoError(t, err) defer rows2.Close() diff --git a/session_cond_test.go b/session_cond_test.go index a80e7d03..4c4f4733 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -122,18 +122,11 @@ func TestIn(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 3, cnt) + department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`" var usrs []Userinfo - err = testEngine.Limit(3).Find(&usrs) - if err != nil { - t.Error(err) - panic(err) - } - - if len(usrs) != 3 { - err = errors.New("there are not 3 records") - t.Error(err) - panic(err) - } + err = testEngine.Where(department+" = ?", "dev").Limit(3).Find(&usrs) + assert.Error(t, err) + assert.EqualValues(t, 3, len(usrs)) var ids []int64 var idsStr string @@ -145,35 +138,20 @@ func TestIn(t *testing.T) { users := make([]Userinfo, 0) err = testEngine.In("(id)", ids[0], ids[1], ids[2]).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) - if len(users) != 3 { - err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 3, len(users)) users = make([]Userinfo, 0) err = testEngine.In("(id)", ids).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) - if len(users) != 3 { - err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 3, len(users)) for _, user := range users { if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) + assert.NoError(t, err) } } @@ -183,87 +161,41 @@ func TestIn(t *testing.T) { idsInterface = append(idsInterface, id) } - department := "`" + testEngine.GetColumnMapper().Obj2Table("Departname") + "`" err = testEngine.Where(department+" = ?", "dev").In("(id)", idsInterface...).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) - - if len(users) != 3 { - err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) - } + assert.EqualValues(t, 3, len(users)) for _, user := range users { if user.Uid != ids[0] && user.Uid != ids[1] && user.Uid != ids[2] { err = errors.New("in uses should be " + idsStr + " total 3") - t.Error(err) - panic(err) + assert.NoError(t, err) } } dev := testEngine.GetColumnMapper().Obj2Table("Dev") err = testEngine.In("(id)", 1).In("(id)", 2).In(department, dev).Find(&users) - - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev-"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update records not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) user := new(Userinfo) has, err := testEngine.ID(ids[0]).Get(user) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - err = errors.New("get record not 1") - t.Error(err) - panic(err) - } - if user.Departname != "dev-" { - err = errors.New("update not success") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "dev-", user.Departname) cnt, err = testEngine.In("(id)", ids[0]).Update(&Userinfo{Departname: "dev"}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update records not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) cnt, err = testEngine.In("(id)", ids[1]).Delete(&Userinfo{}) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("deleted records not 1") - t.Error(err) - panic(err) - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) } func TestFindAndCount(t *testing.T) { diff --git a/session_find_test.go b/session_find_test.go index 04fdb030..de000de5 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -96,21 +96,19 @@ func TestFind(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) for _, user := range users { fmt.Println(user) } users2 := make([]Userinfo, 0) userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo") - err = testEngine.SQL("select * from " + testEngine.Quote(userinfo)).Find(&users2) - if err != nil { - t.Error(err) - panic(err) + var tbName = testEngine.Quote(userinfo) + if testEngine.Dialect().URI().Schema != "" { + tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName } + err = testEngine.SQL("select * from " + tbName).Find(&users2) + assert.NoError(t, err) } func TestFind2(t *testing.T) { @@ -238,14 +236,8 @@ func TestDistinct(t *testing.T) { users := make([]Userinfo, 0) departname := testEngine.GetTableMapper().Obj2Table("Departname") err = testEngine.Distinct(departname).Find(&users) - if err != nil { - t.Error(err) - panic(err) - } - if len(users) != 1 { - t.Error(err) - panic(errors.New("should be one record")) - } + assert.NoError(t, err) + assert.EqualValues(t, 2, len(users)) fmt.Println(users) @@ -255,11 +247,9 @@ func TestDistinct(t *testing.T) { users2 := make([]Depart, 0) err = testEngine.Distinct(departname).Table(new(Userinfo)).Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } - if len(users2) != 1 { + assert.NoError(t, err) + if len(users2) != 2 { + fmt.Println(len(users2)) t.Error(err) panic(errors.New("should be one record")) } @@ -272,18 +262,12 @@ func TestOrder(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.OrderBy("id desc").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) users2 := make([]Userinfo, 0) err = testEngine.Asc("id", "username").Desc("height").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users2) } @@ -293,10 +277,7 @@ func TestHaving(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.GroupBy("username").Having("username='xlw'").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) /*users = make([]Userinfo, 0) @@ -324,18 +305,12 @@ func TestOrderSameMapper(t *testing.T) { users := make([]Userinfo, 0) err := testEngine.OrderBy("(id) desc").Find(&users) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users) users2 := make([]Userinfo, 0) err = testEngine.Asc("(id)", "Username").Desc("Height").Find(&users2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) fmt.Println(users2) } diff --git a/session_pk_test.go b/session_pk_test.go index 3370b2ad..7b025acd 100644 --- a/session_pk_test.go +++ b/session_pk_test.go @@ -1118,13 +1118,28 @@ func TestCompositePK(t *testing.T) { } assert.NoError(t, prepareEngine()) - assertSync(t, new(TaskSolution)) - assert.NoError(t, testEngine.Sync2(new(TaskSolution))) - tables, err := testEngine.DBMetas() + tables1, err := testEngine.DBMetas() assert.NoError(t, err) - assert.EqualValues(t, 1, len(tables)) - pkCols := tables[0].PKColumns() + + assertSync(t, new(TaskSolution)) + assert.NoError(t, testEngine.Sync2(new(TaskSolution))) + + tables2, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1+len(tables1), len(tables2)) + + var table *core.Table + for _, t := range tables2 { + if t.Name == testEngine.GetTableMapper().Obj2Table("TaskSolution") { + table = t + break + } + } + + assert.NotEqual(t, nil, table) + + pkCols := table.PKColumns() assert.EqualValues(t, 2, len(pkCols)) assert.EqualValues(t, "uid", pkCols[0].Name) assert.EqualValues(t, "tid", pkCols[1].Name) diff --git a/session_query_test.go b/session_query_test.go index 7ea413bb..59ce386c 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from get_var2") + records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var2")) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) { _, err := testEngine.Insert(data) assert.NoError(t, err) - records, err := testEngine.QueryString("select * from get_var3") + records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var3")) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 2, len(records[0])) @@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) { _, err := testEngine.InsertOne(data) assert.NoError(t, err) - records, err := testEngine.QueryInterface("select * from get_var_interface") + records, err := testEngine.QueryInterface("select * from " + testEngine.TableNameWithSchema("get_var_interface")) assert.NoError(t, err) assert.Equal(t, 1, len(records)) assert.Equal(t, 5, len(records[0])) @@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, err) assertResult(t, results) - results, err = testEngine.SQL("select * from query_no_params").Query() + results, err = testEngine.SQL("select * from " + testEngine.TableNameWithSchema("query_no_params")).Query() assert.NoError(t, err) assertResult(t, results) } @@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) { assert.EqualValues(t, 3000, money) } - results, err := testEngine.Query(builder.Select("*").From("query_with_builder")) + results, err := testEngine.Query(builder.Select("*").From(testEngine.TableNameWithSchema("query_with_builder"))) assert.NoError(t, err) assertResult(t, results) } diff --git a/session_raw_test.go b/session_raw_test.go index 32e8037c..6fafd848 100644 --- a/session_raw_test.go +++ b/session_raw_test.go @@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoQuery))) - res, err := testEngine.Exec("INSERT INTO `userinfo_query` (uid, name) VALUES (?, ?)", 1, "user") + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableNameWithSchema("`userinfo_query`")+" (uid, name) VALUES (?, ?)", 1, "user") assert.NoError(t, err) cnt, err := res.RowsAffected() assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - results, err := testEngine.Query("select * from userinfo_query") + results, err := testEngine.Query("select * from " + testEngine.TableNameWithSchema("userinfo_query")) assert.NoError(t, err) assert.EqualValues(t, 1, len(results)) id, err := strconv.Atoi(string(results[0]["uid"])) diff --git a/session_schema.go b/session_schema.go index 9d9edca8..e079c6ca 100644 --- a/session_schema.go +++ b/session_schema.go @@ -128,14 +128,12 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } func (session *Session) dropTable(beanOrTableName interface{}) error { - tableName, err := session.engine.tableName(beanOrTableName) - if err != nil { - return err - } - + tableName := session.engine.tbNameNoSchemaString(beanOrTableName) var needDrop = true if !session.engine.dialect.SupportDropIfExists() { + fmt.Println("TableCheckSql:", tableName) sqlStr, args := session.engine.dialect.TableCheckSql(tableName) + fmt.Println("sqlStr:", sqlStr) results, err := session.queryBytes(sqlStr, args...) if err != nil { return err @@ -144,8 +142,8 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { } if needDrop { - sqlStr := session.engine.Dialect().DropTableSql(tableName) - _, err = session.exec(sqlStr) + sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableNameWithSchema(tableName)) + _, err := session.exec(sqlStr) return err } return nil @@ -157,10 +155,7 @@ func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) defer session.Close() } - tableName, err := session.engine.tableName(beanOrTableName) - if err != nil { - return false, err - } + tableName := session.engine.tbNameNoSchemaString(beanOrTableName) return session.isTableExist(tableName) } @@ -190,7 +185,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 - sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName)) + sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.TableNameWithSchema(session.engine.Quote(tableName))) err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { @@ -270,7 +265,8 @@ func (session *Session) Sync2(beans ...interface{}) error { return err } structTables = append(structTables, table) - var tbName = session.tbNameNoSchema(table) + tbName := session.tbNameNoSchema(table) + tbNameWithSchema := engine.TableNameWithSchema(tbName) var oriTable *core.Table for _, tb := range tables { @@ -315,32 +311,32 @@ func (session *Session) Sync2(beans ...interface{}) error { if engine.dialect.DBType() == core.MYSQL || engine.dialect.DBType() == core.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", - tbName, col.Name, curType, expectedType) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, curType, expectedType) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", - tbName, col.Name, curType, expectedType) + tbNameWithSchema, col.Name, curType, expectedType) } } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { if engine.dialect.DBType() == core.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbName, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } } } else { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", - tbName, col.Name, curType, expectedType) + tbNameWithSchema, col.Name, curType, expectedType) } } } else if expectedType == core.Varchar { if engine.dialect.DBType() == core.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", - tbName, col.Name, oriCol.Length, col.Length) - _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col)) + tbNameWithSchema, col.Name, oriCol.Length, col.Length) + _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col)) } } } @@ -354,7 +350,7 @@ func (session *Session) Sync2(beans ...interface{}) error { } } else { session.statement.RefTable = table - session.statement.tableName = tbName + session.statement.tableName = tbNameWithSchema err = session.addColumn(col.Name) } if err != nil { @@ -377,7 +373,7 @@ func (session *Session) Sync2(beans ...interface{}) error { if oriIndex != nil { if oriIndex.Type != index.Type { - sql := engine.dialect.DropIndexSql(tbName, oriIndex) + sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex) _, err = session.exec(sql) if err != nil { return err @@ -393,7 +389,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { - sql := engine.dialect.DropIndexSql(tbName, index2) + sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { return err @@ -404,12 +400,12 @@ func (session *Session) Sync2(beans ...interface{}) error { for name, index := range addedNames { if index.Type == core.UniqueType { session.statement.RefTable = table - session.statement.tableName = tbName - err = session.addUnique(tbName, name) + session.statement.tableName = tbNameWithSchema + err = session.addUnique(tbNameWithSchema, name) } else if index.Type == core.IndexType { session.statement.RefTable = table - session.statement.tableName = tbName - err = session.addIndex(tbName, name) + session.statement.tableName = tbNameWithSchema + err = session.addIndex(tbNameWithSchema, name) } if err != nil { return err @@ -434,7 +430,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, colName := range table.ColumnsSeq() { if oriTable.GetColumn(colName) == nil { - engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName) + engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableNameWithSchema(table.Name), colName) } } } diff --git a/session_tx_test.go b/session_tx_test.go index 7102f5c7..5f2b5817 100644 --- a/session_tx_test.go +++ b/session_tx_test.go @@ -32,45 +32,21 @@ func TestTransaction(t *testing.T) { defer session.Close() err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user2 := Userinfo{Username: "yyy"} _, err = session.Where("(id) = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - fmt.Println(err) - //t.Error(err) - return - } + assert.NoError(t, err) _, err = session.Delete(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - return - } - // panic(err) !nashtsai! should remove this + assert.NoError(t, err) } func TestCombineTransaction(t *testing.T) { @@ -91,38 +67,21 @@ func TestCombineTransaction(t *testing.T) { defer session.Close() err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + assert.NoError(t, err) + user2 := Userinfo{Username: "zzz"} _, err = session.Where("id = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + assert.NoError(t, err) - _, err = session.Exec("delete from userinfo where username = ?", user2.Username) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - } + _, err = session.Exec("delete from "+testEngine.TableNameWithSchema("userinfo")+" where username = ?", user2.Username) + assert.NoError(t, err) err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } func TestCombineTransactionSameMapper(t *testing.T) { @@ -148,45 +107,24 @@ func TestCombineTransactionSameMapper(t *testing.T) { counter() defer counter() + session := testEngine.NewSession() defer session.Close() err := session.Begin() - if err != nil { - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user1 := Userinfo{Username: "xiaoxiao2", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err = session.Insert(&user1) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) user2 := Userinfo{Username: "zzz"} _, err = session.Where("(id) = ?", 0).Update(&user2) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) - _, err = session.Exec("delete from `Userinfo` where `Username` = ?", user2.Username) - if err != nil { - session.Rollback() - t.Error(err) - panic(err) - return - } + _, err = session.Exec("delete from "+testEngine.TableNameWithSchema("`Userinfo`")+" where `Username` = ?", user2.Username) + assert.NoError(t, err) err = session.Commit() - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) } diff --git a/session_update_test.go b/session_update_test.go index d1bc47bc..79205671 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -462,30 +462,18 @@ func TestUpdate1(t *testing.T) { col1 := &UpdateAllCols{Ptr: &s} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateAllCols{col1.Id, true, "", nil} _, err = testEngine.ID(col2.Id).AllCols().Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateAllCols{} has, err = testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !has { err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) @@ -759,7 +747,7 @@ func TestUpdateUpdated(t *testing.T) { func TestUpdateSameMapper(t *testing.T) { assert.NoError(t, prepareEngine()) - oldMapper := testEngine.GetColumnMapper() + oldMapper := testEngine.GetTableMapper() testEngine.UnMapType(rValue(new(Userinfo)).Type()) testEngine.UnMapType(rValue(new(Condi)).Type()) testEngine.UnMapType(rValue(new(Article)).Type()) @@ -786,81 +774,38 @@ func TestUpdateSameMapper(t *testing.T) { var ori Userinfo has, err := testEngine.Get(&ori) - if err != nil { - t.Error(err) - panic(err) - } - if !has { - t.Error(errors.New("not exist")) - panic(errors.New("not exist")) - } + assert.NoError(t, err) + assert.True(t, has) + // update by id user := Userinfo{Username: "xxx", Height: 1.2} cnt, err := testEngine.ID(ori.Uid).Update(&user) - if err != nil { - t.Error(err) - panic(err) - } - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) condi := Condi{"Username": "zzz", "Departname": ""} cnt, err = testEngine.Table(&user).ID(ori.Uid).Update(&condi) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != 1 { - err = errors.New("update not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) cnt, err = testEngine.Update(&Userinfo{Username: "yyy"}, &user) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) total, err := testEngine.Count(&user) - if err != nil { - t.Error(err) - panic(err) - } - - if cnt != total { - err = errors.New("insert not returned 1") - t.Error(err) - panic(err) - return - } + assert.NoError(t, err) + assert.EqualValues(t, cnt, total) err = testEngine.Sync(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) defer func() { err = testEngine.DropTables(&Article{}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) }() a := &Article{0, "1", "2", "3", "4", "5", 2} cnt, err = testEngine.Insert(a) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if cnt != 1 { err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) @@ -875,10 +820,7 @@ func TestUpdateSameMapper(t *testing.T) { } cnt, err = testEngine.ID(a.Id).Update(&Article{Name: "6"}) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if cnt != 1 { err = errors.New(fmt.Sprintf("insert not returned 1 but %d", cnt)) @@ -889,30 +831,18 @@ func TestUpdateSameMapper(t *testing.T) { col1 := &UpdateAllCols{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateAllCols{col1.Id, true, "", nil} _, err = testEngine.ID(col2.Id).AllCols().Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateAllCols{} has, err = testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !has { err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) @@ -931,32 +861,20 @@ func TestUpdateSameMapper(t *testing.T) { { col1 := &UpdateMustCols{} err = testEngine.Sync(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) _, err = testEngine.Insert(col1) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col2 := &UpdateMustCols{col1.Id, true, ""} boolStr := testEngine.GetColumnMapper().Obj2Table("Bool") stringStr := testEngine.GetColumnMapper().Obj2Table("String") _, err = testEngine.ID(col2.Id).MustCols(boolStr, stringStr).Update(col2) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) col3 := &UpdateMustCols{} has, err := testEngine.ID(col2.Id).Get(col3) - if err != nil { - t.Error(err) - panic(err) - } + assert.NoError(t, err) if !has { err = errors.New(fmt.Sprintf("cannot get id %d", col2.Id)) diff --git a/statement.go b/statement.go index 02d73559..2e6a2a6d 100644 --- a/statement.go +++ b/statement.go @@ -225,24 +225,6 @@ func (statement *Statement) setRefValue(v reflect.Value) error { return nil } -// Table tempororily set table name, the parameter could be a string or a pointer of struct -func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { - v := rValue(tableNameOrBean) - t := v.Type() - if t.Kind() == reflect.String { - statement.AltTableName = tableNameOrBean.(string) - } else if t.Kind() == reflect.Struct { - var err error - statement.RefTable, err = statement.Engine.autoMapType(v) - if err != nil { - statement.Engine.logger.Error(err) - return statement - } - statement.AltTableName = statement.Engine.tbName(v) - } - return statement -} - // Auto generating update columnes and values according a struct func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, @@ -743,6 +725,23 @@ func (statement *Statement) Asc(colNames ...string) *Statement { return statement } +// Table tempororily set table name, the parameter could be a string or a pointer of struct +func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { + v := rValue(tableNameOrBean) + t := v.Type() + if t.Kind() == reflect.Struct { + var err error + statement.RefTable, err = statement.Engine.autoMapType(v) + if err != nil { + statement.Engine.logger.Error(err) + return statement + } + } + + statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(tableNameOrBean)) + return statement +} + // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { var buf bytes.Buffer @@ -752,54 +751,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "%v JOIN ", joinOP) } - switch tablename.(type) { - case []string: - t := tablename.([]string) - if len(t) > 1 { - fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) - } else if len(t) == 1 { - fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) - } - case []interface{}: - t := tablename.([]interface{}) - l := len(t) - var table string - if l > 0 { - f := t[0] - switch f.(type) { - case string: - table = f.(string) - case TableName: - table = f.(TableName).TableName() - default: - v := rValue(f) - t := v.Type() - if t.Kind() == reflect.Struct { - fmt.Fprintf(&buf, statement.Engine.tbName(v)) - } else { - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", f))) - } - } - } - if l > 1 { - fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), - statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) - } else if l == 1 { - fmt.Fprintf(&buf, statement.Engine.Quote(table)) - } - case TableName: - fmt.Fprintf(&buf, tablename.(TableName).TableName()) - case string: - fmt.Fprintf(&buf, tablename.(string)) - default: - v := rValue(tablename) - t := v.Type() - if t.Kind() == reflect.Struct { - fmt.Fprintf(&buf, statement.Engine.tbName(v)) - } else { - fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) - } - } + statement.Engine.tbNameNoSchema(&buf, tablename) fmt.Fprintf(&buf, " ON %v", condition) statement.JoinStr = buf.String() @@ -915,7 +867,7 @@ func (statement *Statement) genDelIndexSQL() []string { } sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName)) if statement.Engine.dialect.IndexOnTable() { - sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(statement.TableName())) + sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName)) } sqls = append(sqls, sql) }