add Distinct method & fixed Find use with Table

This commit is contained in:
Lunny Xiao 2013-11-14 23:07:33 +08:00
parent 2a6991886c
commit 1a64d60e06
4 changed files with 71 additions and 10 deletions

View File

@ -1328,6 +1328,37 @@ func testVersion(engine *Engine, t *testing.T) {
} }
} }
func testDistinct(engine *Engine, t *testing.T) {
users := make([]Userinfo, 0)
err := engine.Distinct("departname").Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
if len(users) != 1 {
t.Error(err)
panic(errors.New("should be one record"))
}
fmt.Println(users)
type Depart struct {
Departname string
}
users2 := make([]Depart, 0)
err = engine.Distinct("departname").Table(new(Userinfo)).Find(&users2)
if err != nil {
t.Error(err)
panic(err)
}
if len(users2) != 1 {
t.Error(err)
panic(errors.New("should be one record"))
}
fmt.Println(users2)
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -1414,6 +1445,8 @@ func testAll2(engine *Engine, t *testing.T) {
testStrangeName(engine, t) testStrangeName(engine, t)
fmt.Println("-------------- testVersion --------------") fmt.Println("-------------- testVersion --------------")
testVersion(engine, t) testVersion(engine, t)
fmt.Println("-------------- testDistinct --------------")
testDistinct(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
} }

View File

@ -261,6 +261,12 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session {
return session.StoreEngine(storeEngine) return session.StoreEngine(storeEngine)
} }
func (engine *Engine) Distinct(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Distinct(columns...)
}
func (engine *Engine) Cols(columns ...string) *Session { func (engine *Engine) Cols(columns ...string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true

View File

@ -94,6 +94,11 @@ func (session *Session) Cols(columns ...string) *Session {
return session return session
} }
func (session *Session) Distinct(columns ...string) *Session {
session.Statement.Distinct(columns...)
return session
}
func (session *Session) Omit(columns ...string) *Session { func (session *Session) Omit(columns ...string) *Session {
session.Statement.Omit(columns...) session.Statement.Omit(columns...)
return session return session
@ -238,7 +243,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
return errors.New("Expected a pointer to a struct") return errors.New("Expected a pointer to a struct")
} }
table := session.Engine.Tables[rType(obj)] table := session.Engine.AutoMapType(rType(obj))
for key, data := range objMap { for key, data := range objMap {
if _, ok := table.Columns[key]; !ok { if _, ok := table.Columns[key]; !ok {
@ -848,18 +853,22 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var table *Table var table *Table
if sliceElementType.Kind() == reflect.Ptr { if session.Statement.RefTable == nil {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Kind() == reflect.Ptr {
table = session.Engine.AutoMapType(sliceElementType.Elem()) if sliceElementType.Elem().Kind() == reflect.Struct {
table = session.Engine.AutoMapType(sliceElementType.Elem())
} else {
return errors.New("slice type")
}
} else if sliceElementType.Kind() == reflect.Struct {
table = session.Engine.AutoMapType(sliceElementType)
} else { } else {
return errors.New("slice type") return errors.New("slice type")
} }
} else if sliceElementType.Kind() == reflect.Struct { session.Statement.RefTable = table
table = session.Engine.AutoMapType(sliceElementType)
} else { } else {
return errors.New("slice type") table = session.Statement.RefTable
} }
session.Statement.RefTable = table
if len(condiBean) > 0 { if len(condiBean) > 0 {
colNames, args := buildConditions(session.Engine, table, condiBean[0], true) colNames, args := buildConditions(session.Engine, table, condiBean[0], true)
@ -881,7 +890,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if table.Cacher != nil && session.Statement.UseCache { if table.Cacher != nil &&
session.Statement.UseCache &&
!session.Statement.IsDistinct {
err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return err return err

View File

@ -34,6 +34,7 @@ type Statement struct {
BeanArgs []interface{} BeanArgs []interface{}
UseCache bool UseCache bool
UseAutoTime bool UseAutoTime bool
IsDistinct bool
} }
func (statement *Statement) Init() { func (statement *Statement) Init() {
@ -57,6 +58,7 @@ func (statement *Statement) Init() {
statement.BeanArgs = make([]interface{}, 0) statement.BeanArgs = make([]interface{}, 0)
statement.UseCache = statement.Engine.UseCache statement.UseCache = statement.Engine.UseCache
statement.UseAutoTime = true statement.UseAutoTime = true
statement.IsDistinct = false
} }
func (statement *Statement) Sql(querystring string, args ...interface{}) { func (statement *Statement) Sql(querystring string, args ...interface{}) {
@ -246,6 +248,11 @@ func col2NewCols(columns ...string) []string {
return newColumns return newColumns
} }
func (statement *Statement) Distinct(columns ...string) {
statement.IsDistinct = true
statement.Cols(columns...)
}
func (statement *Statement) Cols(columns ...string) { func (statement *Statement) Cols(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
@ -441,7 +448,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
statement.GroupByStr = columnStr statement.GroupByStr = columnStr
} }
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, var distinct string
if statement.IsDistinct {
distinct = "DISTINCT "
}
a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr,
statement.Engine.Quote(statement.TableName())) statement.Engine.Quote(statement.TableName()))
if statement.JoinStr != "" { if statement.JoinStr != "" {
a = fmt.Sprintf("%v %v", a, statement.JoinStr) a = fmt.Sprintf("%v %v", a, statement.JoinStr)