Merge branch 'master' of github.com:go-xorm/xorm

This commit is contained in:
Lunny Xiao 2018-08-09 22:38:41 +08:00
commit e88d320934
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
8 changed files with 53 additions and 22 deletions

View File

@ -177,6 +177,14 @@ func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr() return engine.dialect.QuoteStr()
} }
func (engine *Engine) quoteColumns(columnStr string) string {
columns := strings.Split(columnStr, ",")
for i := 0; i < len(columns); i++ {
columns[i] = engine.Quote(strings.TrimSpace(columns[i]))
}
return strings.Join(columns, ",")
}
// Quote Use QuoteStr quote the string sql // Quote Use QuoteStr quote the string sql
func (engine *Engine) Quote(value string) string { func (engine *Engine) Quote(value string) string {
value = strings.TrimSpace(value) value = strings.TrimSpace(value)
@ -1333,10 +1341,10 @@ func (engine *Engine) DropIndexes(bean interface{}) error {
} }
// Exec raw sql // Exec raw sql
func (engine *Engine) Exec(sql string, args ...interface{}) (sql.Result, error) { func (engine *Engine) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
return session.Exec(sql, args...) return session.Exec(sqlorArgs...)
} }
// Query a raw sql and return records as []map[string][]byte // Query a raw sql and return records as []map[string][]byte

View File

@ -91,7 +91,7 @@ func main() {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL(true)
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("test end") fmt.Println("test end")

View File

@ -27,7 +27,7 @@ type Interface interface {
Delete(interface{}) (int64, error) Delete(interface{}) (int64, error)
Distinct(columns ...string) *Session Distinct(columns ...string) *Session
DropIndexes(bean interface{}) error DropIndexes(bean interface{}) error
Exec(string, ...interface{}) (sql.Result, error) Exec(sqlOrAgrs ...interface{}) (sql.Result, error)
Exist(bean ...interface{}) (bool, error) Exist(bean ...interface{}) (bool, error)
Find(interface{}, ...interface{}) error Find(interface{}, ...interface{}) error
FindAndCount(interface{}, ...interface{}) (int64, error) FindAndCount(interface{}, ...interface{}) (int64, error)

View File

@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
if session.statement.JoinStr == "" { if session.statement.JoinStr == "" {
if columnStr == "" { if columnStr == "" {
if session.statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else { } else {
columnStr = session.statement.genColumnStr() columnStr = session.statement.genColumnStr()
} }
@ -143,7 +143,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
} else { } else {
if columnStr == "" { if columnStr == "" {
if session.statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else { } else {
columnStr = "*" columnStr = "*"
} }

View File

@ -268,6 +268,15 @@ func TestOrder(t *testing.T) {
fmt.Println(users2) fmt.Println(users2)
} }
func TestGroupBy(t *testing.T) {
assert.NoError(t, prepareEngine())
assertSync(t, new(Userinfo))
users := make([]Userinfo, 0)
err := testEngine.GroupBy("id, username").Find(&users)
assert.NoError(t, err)
}
func TestHaving(t *testing.T) { func TestHaving(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
assertSync(t, new(Userinfo)) assertSync(t, new(Userinfo))

View File

@ -17,17 +17,7 @@ import (
func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) { func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) {
if len(sqlorArgs) > 0 { if len(sqlorArgs) > 0 {
switch sqlorArgs[0].(type) { return convertSQLOrArgs(sqlorArgs...)
case string:
return sqlorArgs[0].(string), sqlorArgs[1:], nil
case *builder.Builder:
return sqlorArgs[0].(*builder.Builder).ToSQL()
case builder.Builder:
bd := sqlorArgs[0].(builder.Builder)
return bd.ToSQL()
default:
return "", nil, ErrUnSupportedType
}
} }
if session.statement.RawSQL != "" { if session.statement.RawSQL != "" {
@ -45,7 +35,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
if session.statement.JoinStr == "" { if session.statement.JoinStr == "" {
if columnStr == "" { if columnStr == "" {
if session.statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else { } else {
columnStr = session.statement.genColumnStr() columnStr = session.statement.genColumnStr()
} }
@ -53,7 +43,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa
} else { } else {
if columnStr == "" { if columnStr == "" {
if session.statement.GroupByStr != "" { if session.statement.GroupByStr != "" {
columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) columnStr = session.engine.quoteColumns(session.statement.GroupByStr)
} else { } else {
columnStr = "*" columnStr = "*"
} }

View File

@ -9,6 +9,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/go-xorm/builder"
"github.com/go-xorm/core" "github.com/go-xorm/core"
) )
@ -193,11 +194,34 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
return session.DB().Exec(sqlStr, args...) return session.DB().Exec(sqlStr, args...)
} }
func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {
switch sqlorArgs[0].(type) {
case string:
return sqlorArgs[0].(string), sqlorArgs[1:], nil
case *builder.Builder:
return sqlorArgs[0].(*builder.Builder).ToSQL()
case builder.Builder:
bd := sqlorArgs[0].(builder.Builder)
return bd.ToSQL()
}
return "", nil, ErrUnSupportedType
}
// Exec raw sql // Exec raw sql
func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sqlorArgs ...interface{}) (sql.Result, error) {
if session.isAutoClose { if session.isAutoClose {
defer session.Close() defer session.Close()
} }
if len(sqlorArgs) == 0 {
return nil, ErrUnSupportedType
}
sqlStr, args, err := convertSQLOrArgs(sqlorArgs...)
if err != nil {
return nil, err
}
return session.exec(sqlStr, args...) return session.exec(sqlStr, args...)
} }

View File

@ -933,7 +933,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
if len(statement.JoinStr) == 0 { if len(statement.JoinStr) == 0 {
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
} else { } else {
columnStr = statement.genColumnStr() columnStr = statement.genColumnStr()
} }
@ -941,7 +941,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{},
} else { } else {
if len(columnStr) == 0 { if len(columnStr) == 0 {
if len(statement.GroupByStr) > 0 { if len(statement.GroupByStr) > 0 {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
} }
} }
} }