From 1a64d60e065b9eefe369abfddceb1c8c0c2af2fb Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 14 Nov 2013 23:07:33 +0800 Subject: [PATCH] add Distinct method & fixed Find use with Table --- base_test.go | 33 +++++++++++++++++++++++++++++++++ engine.go | 6 ++++++ session.go | 29 ++++++++++++++++++++--------- statement.go | 13 ++++++++++++- 4 files changed, 71 insertions(+), 10 deletions(-) diff --git a/base_test.go b/base_test.go index 66144b90..2cac2d4c 100644 --- a/base_test.go +++ b/base_test.go @@ -1328,6 +1328,37 @@ func testVersion(engine *Engine, t *testing.T) { } } +func testDistinct(engine *Engine, t *testing.T) { + users := make([]Userinfo, 0) + err := engine.Distinct("departname").Find(&users) + if err != nil { + t.Error(err) + panic(err) + } + if len(users) != 1 { + t.Error(err) + panic(errors.New("should be one record")) + } + + fmt.Println(users) + + type Depart struct { + Departname string + } + + users2 := make([]Depart, 0) + err = engine.Distinct("departname").Table(new(Userinfo)).Find(&users2) + if err != nil { + t.Error(err) + panic(err) + } + if len(users2) != 1 { + t.Error(err) + panic(errors.New("should be one record")) + } + fmt.Println(users2) +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -1414,6 +1445,8 @@ func testAll2(engine *Engine, t *testing.T) { testStrangeName(engine, t) fmt.Println("-------------- testVersion --------------") testVersion(engine, t) + fmt.Println("-------------- testDistinct --------------") + testDistinct(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/engine.go b/engine.go index dd99d50a..88a26385 100644 --- a/engine.go +++ b/engine.go @@ -261,6 +261,12 @@ func (engine *Engine) StoreEngine(storeEngine string) *Session { return session.StoreEngine(storeEngine) } +func (engine *Engine) Distinct(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Distinct(columns...) +} + func (engine *Engine) Cols(columns ...string) *Session { session := engine.NewSession() session.IsAutoClose = true diff --git a/session.go b/session.go index 4f0146af..6bf73274 100644 --- a/session.go +++ b/session.go @@ -94,6 +94,11 @@ func (session *Session) Cols(columns ...string) *Session { return session } +func (session *Session) Distinct(columns ...string) *Session { + session.Statement.Distinct(columns...) + return session +} + func (session *Session) Omit(columns ...string) *Session { session.Statement.Omit(columns...) return session @@ -238,7 +243,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b return errors.New("Expected a pointer to a struct") } - table := session.Engine.Tables[rType(obj)] + table := session.Engine.AutoMapType(rType(obj)) for key, data := range objMap { if _, ok := table.Columns[key]; !ok { @@ -848,18 +853,22 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() var table *Table - if sliceElementType.Kind() == reflect.Ptr { - if sliceElementType.Elem().Kind() == reflect.Struct { - table = session.Engine.AutoMapType(sliceElementType.Elem()) + if session.Statement.RefTable == nil { + if sliceElementType.Kind() == reflect.Ptr { + if sliceElementType.Elem().Kind() == reflect.Struct { + table = session.Engine.AutoMapType(sliceElementType.Elem()) + } else { + return errors.New("slice type") + } + } else if sliceElementType.Kind() == reflect.Struct { + table = session.Engine.AutoMapType(sliceElementType) } else { return errors.New("slice type") } - } else if sliceElementType.Kind() == reflect.Struct { - table = session.Engine.AutoMapType(sliceElementType) + session.Statement.RefTable = table } else { - return errors.New("slice type") + table = session.Statement.RefTable } - session.Statement.RefTable = table if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true) @@ -881,7 +890,9 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) args = session.Statement.RawParams } - if table.Cacher != nil && session.Statement.UseCache { + if table.Cacher != nil && + session.Statement.UseCache && + !session.Statement.IsDistinct { err = session.cacheFind(sliceElementType, sql, rowsSlicePtr, args...) if err != ErrCacheFailed { return err diff --git a/statement.go b/statement.go index acadf25b..6d2f3cc7 100644 --- a/statement.go +++ b/statement.go @@ -34,6 +34,7 @@ type Statement struct { BeanArgs []interface{} UseCache bool UseAutoTime bool + IsDistinct bool } func (statement *Statement) Init() { @@ -57,6 +58,7 @@ func (statement *Statement) Init() { statement.BeanArgs = make([]interface{}, 0) statement.UseCache = statement.Engine.UseCache statement.UseAutoTime = true + statement.IsDistinct = false } func (statement *Statement) Sql(querystring string, args ...interface{}) { @@ -246,6 +248,11 @@ func col2NewCols(columns ...string) []string { return newColumns } +func (statement *Statement) Distinct(columns ...string) { + statement.IsDistinct = true + statement.Cols(columns...) +} + func (statement *Statement) Cols(columns ...string) { newColumns := col2NewCols(columns...) for _, nc := range newColumns { @@ -441,7 +448,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) statement.GroupByStr = columnStr } - a = fmt.Sprintf("SELECT %v FROM %v", columnStr, + var distinct string + if statement.IsDistinct { + distinct = "DISTINCT " + } + a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr, statement.Engine.Quote(statement.TableName())) if statement.JoinStr != "" { a = fmt.Sprintf("%v %v", a, statement.JoinStr)