improve the interface of EngineInterface

This commit is contained in:
Lunny Xiao 2018-04-10 09:27:44 +08:00
parent a21c00771f
commit ea6ebc3514
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
16 changed files with 61 additions and 61 deletions

View File

@ -819,7 +819,7 @@ func (engine *Engine) TableInfo(bean interface{}) *Table {
if err != nil {
engine.logger.Error(err)
}
return &Table{tb, engine.tableName(bean)}
return &Table{tb, engine.TableName(bean)}
}
func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) {
@ -1134,7 +1134,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id string) error {
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tableName(bean)
tableName := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err
@ -1158,7 +1158,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
if t.Kind() != reflect.Struct {
return errors.New("error params")
}
tableName := engine.tableName(bean)
tableName := engine.TableName(bean)
table, err := engine.autoMapType(v)
if err != nil {
return err

View File

@ -13,7 +13,7 @@ import (
)
// TableNameWithSchema will automatically add schema prefix on table name
func (engine *Engine) TableNameWithSchema(v string) string {
func (engine *Engine) tbNameWithSchema(v string) string {
// Add schema name as prefix of table name.
// Only for postgres database.
if engine.dialect.DBType() == core.POSTGRES &&
@ -25,8 +25,23 @@ func (engine *Engine) TableNameWithSchema(v string) string {
return v
}
func (engine *Engine) tableName(bean interface{}) string {
return engine.TableNameWithSchema(engine.tbNameNoSchema(bean))
// TableName returns table name with schema prefix if has
func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string {
tbName := engine.tbNameNoSchema(bean)
if len(includeSchema) > 0 && includeSchema[0] {
tbName = engine.tbNameWithSchema(tbName)
}
return tbName
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
func (engine *Engine) tbNameForMap(v reflect.Value) string {

View File

@ -95,7 +95,7 @@ type EngineInterface interface {
Sync2(...interface{}) error
StoreEngine(storeEngine string) *Session
TableInfo(bean interface{}) *Table
TableNameWithSchema(string) string
TableName(interface{}, ...bool) string
UnMapType(reflect.Type)
}

View File

@ -54,10 +54,7 @@ func TestRows(t *testing.T) {
}
assert.EqualValues(t, 1, cnt)
var tbName = testEngine.Quote("user_rows")
if testEngine.Dialect().URI().Schema != "" {
tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName
}
var tbName = testEngine.Quote(testEngine.TableName(user, true))
rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
assert.NoError(t, err)
defer rows2.Close()

View File

@ -828,15 +828,6 @@ func (session *Session) LastSQL() (string, []interface{}) {
return session.lastSQL, session.lastSQLArgs
}
// tbName get some table's table name
func (session *Session) tbNameNoSchema(table *core.Table) string {
if len(session.statement.AltTableName) > 0 {
return session.statement.AltTableName
}
return table.Name
}
// Unscoped always disable struct tag "deleted"
func (session *Session) Unscoped() *Session {
session.statement.Unscoped()

View File

@ -54,11 +54,11 @@ func TestExistStruct(t *testing.T) {
assert.NoError(t, err)
assert.False(t, has)
has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test1").Exist()
has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test1").Exist()
assert.NoError(t, err)
assert.True(t, has)
has, err = testEngine.SQL("select * from "+testEngine.TableNameWithSchema("record_exist")+" where name = ?", "test2").Exist()
has, err = testEngine.SQL("select * from "+testEngine.TableName("record_exist", true)+" where name = ?", "test2").Exist()
assert.NoError(t, err)
assert.False(t, has)

View File

@ -102,11 +102,7 @@ func TestFind(t *testing.T) {
}
users2 := make([]Userinfo, 0)
userinfo := testEngine.GetTableMapper().Obj2Table("Userinfo")
var tbName = testEngine.Quote(userinfo)
if testEngine.Dialect().URI().Schema != "" {
tbName = testEngine.Quote(testEngine.Dialect().URI().Schema) + "." + tbName
}
var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true))
err = testEngine.SQL("select * from " + tbName).Find(&users2)
assert.NoError(t, err)
}

View File

@ -84,7 +84,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money))
var money2 float64
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableNameWithSchema("get_var") + " LIMIT 1").Get(&money2)
has, err = testEngine.SQL("SELECT money FROM " + testEngine.TableName("get_var", true) + " LIMIT 1").Get(&money2)
assert.NoError(t, err)
assert.Equal(t, true, has)
assert.Equal(t, "1.5", fmt.Sprintf("%.1f", money2))

View File

@ -36,7 +36,7 @@ func TestQueryString(t *testing.T) {
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var2"))
records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var2", true))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
@ -62,7 +62,7 @@ func TestQueryString2(t *testing.T) {
_, err := testEngine.Insert(data)
assert.NoError(t, err)
records, err := testEngine.QueryString("select * from " + testEngine.TableNameWithSchema("get_var3"))
records, err := testEngine.QueryString("select * from " + testEngine.TableName("get_var3", true))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 2, len(records[0]))
@ -127,7 +127,7 @@ func TestQueryInterface(t *testing.T) {
_, err := testEngine.InsertOne(data)
assert.NoError(t, err)
records, err := testEngine.QueryInterface("select * from " + testEngine.TableNameWithSchema("get_var_interface"))
records, err := testEngine.QueryInterface("select * from " + testEngine.TableName("get_var_interface", true))
assert.NoError(t, err)
assert.Equal(t, 1, len(records))
assert.Equal(t, 5, len(records[0]))
@ -181,7 +181,7 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, err)
assertResult(t, results)
results, err = testEngine.SQL("select * from " + testEngine.TableNameWithSchema("query_no_params")).Query()
results, err = testEngine.SQL("select * from " + testEngine.TableName("query_no_params", true)).Query()
assert.NoError(t, err)
assertResult(t, results)
}
@ -226,7 +226,7 @@ func TestQueryWithBuilder(t *testing.T) {
assert.EqualValues(t, 3000, money)
}
results, err := testEngine.Query(builder.Select("*").From(testEngine.TableNameWithSchema("query_with_builder")))
results, err := testEngine.Query(builder.Select("*").From(testEngine.TableName("query_with_builder", true)))
assert.NoError(t, err)
assertResult(t, results)
}

View File

@ -21,13 +21,13 @@ func TestExecAndQuery(t *testing.T) {
assert.NoError(t, testEngine.Sync2(new(UserinfoQuery)))
res, err := testEngine.Exec("INSERT INTO "+testEngine.TableNameWithSchema("`userinfo_query`")+" (uid, name) VALUES (?, ?)", 1, "user")
res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_query`", true)+" (uid, name) VALUES (?, ?)", 1, "user")
assert.NoError(t, err)
cnt, err := res.RowsAffected()
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
results, err := testEngine.Query("select * from " + testEngine.TableNameWithSchema("userinfo_query"))
results, err := testEngine.Query("select * from " + testEngine.TableName("userinfo_query", true))
assert.NoError(t, err)
assert.EqualValues(t, 1, len(results))
id, err := strconv.Atoi(string(results[0]["uid"]))

View File

@ -134,7 +134,7 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
}
if needDrop {
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableNameWithSchema(tableName))
sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
_, err := session.exec(sqlStr)
return err
}
@ -168,7 +168,7 @@ func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
func (session *Session) isTableEmpty(tableName string) (bool, error) {
var total int64
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.TableNameWithSchema(session.engine.Quote(tableName)))
sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
err := session.queryRow(sqlStr).Scan(&total)
if err != nil {
if err == sql.ErrNoRows {
@ -249,7 +249,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
}
structTables = append(structTables, table)
tbName := session.tbNameNoSchema(table)
tbNameWithSchema := engine.TableNameWithSchema(tbName)
tbNameWithSchema := engine.TableName(tbName, true)
var oriTable *core.Table
for _, tb := range tables {
@ -413,7 +413,7 @@ func (session *Session) Sync2(beans ...interface{}) error {
for _, colName := range table.ColumnsSeq() {
if oriTable.GetColumn(colName) == nil {
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableNameWithSchema(table.Name), colName)
engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
}
}
}

View File

@ -153,7 +153,7 @@ func TestSQLCount(t *testing.T) {
assertSync(t, new(UserinfoCount2), new(UserinfoBooks))
total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableNameWithSchema("userinfo_count2")).
total, err := testEngine.SQL("SELECT count(id) FROM " + testEngine.TableName("userinfo_count2", true)).
Count()
assert.NoError(t, err)
assert.EqualValues(t, 0, total)

View File

@ -77,7 +77,7 @@ func TestCombineTransaction(t *testing.T) {
_, err = session.Where("id = ?", 0).Update(&user2)
assert.NoError(t, err)
_, err = session.Exec("delete from "+testEngine.TableNameWithSchema("userinfo")+" where username = ?", user2.Username)
_, err = session.Exec("delete from "+testEngine.TableName("userinfo", true)+" where username = ?", user2.Username)
assert.NoError(t, err)
err = session.Commit()
@ -122,7 +122,7 @@ func TestCombineTransactionSameMapper(t *testing.T) {
_, err = session.Where("(id) = ?", 0).Update(&user2)
assert.NoError(t, err)
_, err = session.Exec("delete from "+testEngine.TableNameWithSchema("`Userinfo`")+" where `Username` = ?", user2.Username)
_, err = session.Exec("delete from "+testEngine.TableName("`Userinfo`", true)+" where `Username` = ?", user2.Username)
assert.NoError(t, err)
err = session.Commit()

View File

@ -221,7 +221,7 @@ func (statement *Statement) setRefValue(v reflect.Value) error {
if err != nil {
return err
}
statement.tableName = statement.Engine.tableName(v.Interface())
statement.tableName = statement.Engine.TableName(v.Interface(), true)
return nil
}
@ -231,7 +231,7 @@ func (statement *Statement) setRefBean(bean interface{}) error {
if err != nil {
return err
}
statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(bean))
statement.tableName = statement.Engine.TableName(bean, true)
return nil
}
@ -748,7 +748,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
}
}
statement.AltTableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tableNameOrBean))
statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
return statement
}
@ -761,7 +761,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
fmt.Fprintf(&buf, "%v JOIN ", joinOP)
}
tbName := statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchema(tablename))
tbName := statement.Engine.TableName(tablename, true)
fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
statement.JoinStr = buf.String()
@ -879,7 +879,7 @@ func (statement *Statement) genDelIndexSQL() []string {
} else if index.Type == core.IndexType {
rIdxName = indexName(idxPrefixName, idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.TableNameWithSchema(statement.Engine.Quote(rIdxName)))
sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
if statement.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
}

View File

@ -202,9 +202,9 @@ func TestExtends(t *testing.T) {
var info UserAndDetail
qt := testEngine.Quote
ui := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userinfo"))
ud := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("Userdetail"))
uiid := testEngine.GetTableMapper().Obj2Table("Id")
ui := testEngine.TableName(new(Userinfo), true)
ud := testEngine.TableName(&detail, true)
uiid := testEngine.GetColumnMapper().Obj2Table("Id")
udid := "detail_id"
sql := fmt.Sprintf("select * from %s, %s where %s.%s = %s.%s",
qt(ui), qt(ud), qt(ui), qt(udid), qt(ud), qt(uiid))
@ -339,9 +339,9 @@ func TestExtends2(t *testing.T) {
var mapper = testEngine.GetTableMapper().Obj2Table
var quote = testEngine.Quote
userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
msgTableName := quote(testEngine.TableName(mapper("Message"), true))
list := make([]Message, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
@ -405,9 +405,9 @@ func TestExtends3(t *testing.T) {
var mapper = testEngine.GetTableMapper().Obj2Table
var quote = testEngine.Quote
userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
msgTableName := quote(testEngine.TableName(mapper("Message"), true))
list := make([]MessageExtend3, 0)
err = testEngine.Table(msgTableName).Join("LEFT", []string{userTableName, "sender"}, "`sender`.`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").
@ -490,9 +490,9 @@ func TestExtends4(t *testing.T) {
var mapper = testEngine.GetTableMapper().Obj2Table
var quote = testEngine.Quote
userTableName := quote(testEngine.TableNameWithSchema(mapper("MessageUser")))
typeTableName := quote(testEngine.TableNameWithSchema(mapper("MessageType")))
msgTableName := quote(testEngine.TableNameWithSchema(mapper("Message")))
userTableName := quote(testEngine.TableName(mapper("MessageUser"), true))
typeTableName := quote(testEngine.TableName(mapper("MessageType"), true))
msgTableName := quote(testEngine.TableName(mapper("Message"), true))
list := make([]MessageExtend4, 0)
err = testEngine.Table(msgTableName).Join("LEFT", userTableName, userTableName+".`"+mapper("Id")+"`="+msgTableName+".`"+mapper("Uid")+"`").

View File

@ -301,10 +301,11 @@ type UserCus struct {
func TestCustomType2(t *testing.T) {
assert.NoError(t, prepareEngine())
err := testEngine.CreateTables(&UserCus{})
var uc UserCus
err := testEngine.CreateTables(&uc)
assert.NoError(t, err)
tableName := testEngine.TableNameWithSchema(testEngine.GetTableMapper().Obj2Table("UserCus"))
tableName := testEngine.TableName(&uc, true)
_, err = testEngine.Exec("delete from " + testEngine.Quote(tableName))
assert.NoError(t, err)