added Iterate method; added Omit method

This commit is contained in:
Lunny Xiao 2013-10-17 12:50:46 +08:00
parent b618c3870d
commit fc17734817
8 changed files with 248 additions and 65 deletions

View File

@ -1115,6 +1115,19 @@ func testMetaInfo(engine *Engine, t *testing.T) {
} }
} }
func testIterate(engine *Engine, t *testing.T) {
err := engine.Omit("is_man").Iterate(new(Userinfo), func(idx int, bean interface{}) error {
user := bean.(*Userinfo)
fmt.Println(idx, "--", user)
return nil
})
if err != nil {
t.Error(err)
panic(err)
}
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -1195,6 +1208,8 @@ func testAll2(engine *Engine, t *testing.T) {
//testInt32Id(engine, t) //testInt32Id(engine, t)
fmt.Println("-------------- testMetaInfo --------------") fmt.Println("-------------- testMetaInfo --------------")
testMetaInfo(engine, t) testMetaInfo(engine, t)
fmt.Println("-------------- testIterate --------------")
testIterate(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
} }

View File

@ -250,6 +250,12 @@ func (engine *Engine) Cols(columns ...string) *Session {
return session.Cols(columns...) return session.Cols(columns...)
} }
func (engine *Engine) Omit(columns ...string) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Omit(columns...)
}
/*func (engine *Engine) Trans(t string) *Session { /*func (engine *Engine) Trans(t string) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
@ -799,6 +805,12 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
return session.Find(beans, condiBeans...) return session.Find(beans, condiBeans...)
} }
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error {
session := engine.NewSession()
defer session.Close()
return session.Iterate(bean, fun)
}
func (engine *Engine) Count(bean interface{}) (int64, error) { func (engine *Engine) Count(bean interface{}) (int64, error) {
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()

35
iterator.go Normal file
View File

@ -0,0 +1,35 @@
package xorm
import (
"database/sql"
)
type Iterator struct {
session *Session
startId int
rows *sql.Rows
}
func (iter *Iterator) IsValid() bool {
return iter.session != nil && iter.rows != nil
}
/*
func (iter *Iterator) Next(bean interface{}) (bool, error) {
if !iter.IsValid() {
return errors.New("iterator is not valied.")
}
if iter.rows.Next() {
iter.rows.Scan(...)
}
}*/
// close the iterator
func (iter *Iterator) Close() {
if iter.rows != nil {
iter.rows.Close()
}
if iter.session != nil && iter.session.IsAutoClose {
iter.session.Close()
}
}

View File

@ -101,6 +101,10 @@ func TestPostgres2(t *testing.T) {
testCreatedAndUpdated(engine, t) testCreatedAndUpdated(engine, t)
fmt.Println("-------------- testIndexAndUnique --------------") fmt.Println("-------------- testIndexAndUnique --------------")
testIndexAndUnique(engine, t) testIndexAndUnique(engine, t)
fmt.Println("-------------- testMetaInfo --------------")
testMetaInfo(engine, t)
fmt.Println("-------------- testIterate --------------")
testIterate(engine, t)
} }
func BenchmarkPostgresNoCache(t *testing.B) { func BenchmarkPostgresNoCache(t *testing.B) {

View File

@ -82,6 +82,11 @@ func (session *Session) Cols(columns ...string) *Session {
return session return session
} }
func (session *Session) Omit(columns ...string) *Session {
session.Statement.Omit(columns...)
return session
}
// Method NoAutoTime means do not automatically give created field and updated field // Method NoAutoTime means do not automatically give created field and updated field
// the current time on the current session temporarily // the current time on the current session temporarily
func (session *Session) NoAutoTime() *Session { func (session *Session) NoAutoTime() *Session {
@ -663,6 +668,71 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
return nil return nil
} }
type IterFunc func(idx int, bean interface{}) error
func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
err := session.newDb()
if err != nil {
return err
}
defer session.Statement.Init()
if session.IsAutoClose {
defer session.Close()
}
var sql string
var args []interface{}
session.Statement.RefTable = session.Engine.AutoMap(bean)
if session.Statement.RawSQL == "" {
sql, args = session.Statement.genGetSql(bean)
} else {
sql = session.Statement.RawSQL
args = session.Statement.RawParams
}
for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session)
}
session.Engine.LogSQL(sql)
session.Engine.LogSQL(args)
s, err := session.Db.Prepare(sql)
if err != nil {
return err
}
defer s.Close()
rows, err := s.Query(args...)
if err != nil {
return err
}
defer rows.Close()
fields, err := rows.Columns()
if err != nil {
return err
}
t := reflect.Indirect(reflect.ValueOf(bean)).Type()
b := reflect.New(t).Interface()
i := 0
for rows.Next() {
result, err := row2map(rows, fields)
if err == nil {
err = session.scanMapIntoStruct(b, result)
}
if err == nil {
err = fun(i, b)
i = i + 1
}
if err != nil {
return err
}
}
return nil
}
// get retrieve one record from database // get retrieve one record from database
func (session *Session) Get(bean interface{}) (bool, error) { func (session *Session) Get(bean interface{}) (bool, error) {
err := session.newDb() err := session.newDb()
@ -982,29 +1052,14 @@ func (session *Session) DropAll() error {
return nil return nil
} }
func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err error) {
s, err := db.Prepare(sql)
if err != nil {
return nil, err
}
defer s.Close()
res, err := s.Query(params...)
if err != nil {
return nil, err
}
defer res.Close()
fields, err := res.Columns()
if err != nil {
return nil, err
}
for res.Next() {
result := make(map[string][]byte) result := make(map[string][]byte)
var scanResultContainers []interface{} var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var scanResultContainer interface{} var scanResultContainer interface{}
scanResultContainers = append(scanResultContainers, &scanResultContainer) scanResultContainers = append(scanResultContainers, &scanResultContainer)
} }
if err := res.Scan(scanResultContainers...); err != nil { if err := rows.Scan(scanResultContainers...); err != nil {
return nil, err return nil, err
} }
for ii, key := range fields { for ii, key := range fields {
@ -1050,11 +1105,40 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st
//session.Engine.LogError("Unsupported type") //session.Engine.LogError("Unsupported type")
} }
} }
return result, nil
}
func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) {
fields, err := rows.Columns()
if err != nil {
return nil, err
}
for rows.Next() {
result, err := row2map(rows, fields)
if err != nil {
return nil, err
}
resultsSlice = append(resultsSlice, result) resultsSlice = append(resultsSlice, result)
} }
return resultsSlice, nil return resultsSlice, nil
} }
func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
s, err := db.Prepare(sql)
if err != nil {
return nil, err
}
defer s.Close()
rows, err := s.Query(params...)
if err != nil {
return nil, err
}
defer rows.Close()
return rows2maps(rows)
}
func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session) sql = filter.Do(sql, session)

View File

@ -22,6 +22,7 @@ type Statement struct {
HavingStr string HavingStr string
ColumnStr string ColumnStr string
columnMap map[string]bool columnMap map[string]bool
OmitStr string
ConditionStr string ConditionStr string
AltTableName string AltTableName string
RawSQL string RawSQL string
@ -47,6 +48,7 @@ func (statement *Statement) Init() {
statement.GroupByStr = "" statement.GroupByStr = ""
statement.HavingStr = "" statement.HavingStr = ""
statement.ColumnStr = "" statement.ColumnStr = ""
statement.OmitStr = ""
statement.columnMap = make(map[string]bool) statement.columnMap = make(map[string]bool)
statement.ConditionStr = "" statement.ConditionStr = ""
statement.AltTableName = "" statement.AltTableName = ""
@ -82,9 +84,10 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
for _, col := range table.Columns { for _, col := range table.Columns {
fieldValue := col.ValueOf(bean) fieldValue := col.ValueOf(bean)
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface() var val interface{}
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Bool: case reflect.Bool:
val = fieldValue.Interface()
case reflect.String: case reflect.String:
if fieldValue.String() == "" { if fieldValue.String() == "" {
continue continue
@ -92,19 +95,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
// for MyString, should convert to string or panic // for MyString, should convert to string or panic
if fieldType.String() != reflect.String.String() { if fieldType.String() != reflect.String.String() {
val = fieldValue.String() val = fieldValue.String()
} else {
val = fieldValue.Interface()
} }
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 { if fieldValue.Int() == 0 {
continue continue
} }
val = fieldValue.Interface()
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
if fieldValue.Float() == 0.0 { if fieldValue.Float() == 0.0 {
continue continue
} }
val = fieldValue.Interface()
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
if fieldValue.Uint() == 0 { if fieldValue.Uint() == 0 {
continue continue
} }
val = fieldValue.Interface()
case reflect.Struct: case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) { if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time) t := fieldValue.Interface().(time.Time)
@ -121,6 +129,8 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
} else { } else {
continue continue
} }
} else {
val = fieldValue.Interface()
} }
} }
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice, reflect.Map:
@ -219,6 +229,14 @@ func (statement *Statement) Cols(columns ...string) {
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.columnMap[nc] = false
}
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
}
func (statement *Statement) Limit(limit int, start ...int) { func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit statement.LimitN = limit
if len(start) > 0 { if len(start) > 0 {
@ -251,10 +269,16 @@ func (statement *Statement) genColumnStr() string {
table := statement.RefTable table := statement.RefTable
colNames := make([]string, 0) colNames := make([]string, 0)
for _, col := range table.Columns { for _, col := range table.Columns {
if col.MapType != ONLYTODB { if statement.OmitStr != "" {
colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name)) if _, ok := statement.columnMap[col.Name]; ok {
continue
} }
} }
if col.MapType == ONLYTODB {
continue
}
colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name))
}
return strings.Join(colNames, ", ") return strings.Join(colNames, ", ")
} }

View File

@ -313,6 +313,11 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc
continue continue
} }
} }
if session.Statement.OmitStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; ok {
continue
}
}
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
args = append(args, time.Now()) args = append(args, time.Now())

View File

@ -43,5 +43,9 @@ func unTitle(src string) string {
return "" return ""
} }
if len(src) == 1 {
return strings.ToLower(string(src[0]))
} else {
return strings.ToLower(string(src[0])) + src[1:] return strings.ToLower(string(src[0])) + src[1:]
}
} }