added Iterate method; added Omit method
This commit is contained in:
parent
b618c3870d
commit
fc17734817
15
base_test.go
15
base_test.go
|
@ -1115,6 +1115,19 @@ func testMetaInfo(engine *Engine, t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testIterate(engine *Engine, t *testing.T) {
|
||||
err := engine.Omit("is_man").Iterate(new(Userinfo), func(idx int, bean interface{}) error {
|
||||
user := bean.(*Userinfo)
|
||||
fmt.Println(idx, "--", user)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func testAll(engine *Engine, t *testing.T) {
|
||||
fmt.Println("-------------- directCreateTable --------------")
|
||||
directCreateTable(engine, t)
|
||||
|
@ -1195,6 +1208,8 @@ func testAll2(engine *Engine, t *testing.T) {
|
|||
//testInt32Id(engine, t)
|
||||
fmt.Println("-------------- testMetaInfo --------------")
|
||||
testMetaInfo(engine, t)
|
||||
fmt.Println("-------------- testIterate --------------")
|
||||
testIterate(engine, t)
|
||||
fmt.Println("-------------- transaction --------------")
|
||||
transaction(engine, t)
|
||||
}
|
||||
|
|
12
engine.go
12
engine.go
|
@ -250,6 +250,12 @@ func (engine *Engine) Cols(columns ...string) *Session {
|
|||
return session.Cols(columns...)
|
||||
}
|
||||
|
||||
func (engine *Engine) Omit(columns ...string) *Session {
|
||||
session := engine.NewSession()
|
||||
session.IsAutoClose = true
|
||||
return session.Omit(columns...)
|
||||
}
|
||||
|
||||
/*func (engine *Engine) Trans(t string) *Session {
|
||||
session := engine.NewSession()
|
||||
session.IsAutoClose = true
|
||||
|
@ -799,6 +805,12 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error {
|
|||
return session.Find(beans, condiBeans...)
|
||||
}
|
||||
|
||||
func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error {
|
||||
session := engine.NewSession()
|
||||
defer session.Close()
|
||||
return session.Iterate(bean, fun)
|
||||
}
|
||||
|
||||
func (engine *Engine) Count(bean interface{}) (int64, error) {
|
||||
session := engine.NewSession()
|
||||
defer session.Close()
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
package xorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Iterator struct {
|
||||
session *Session
|
||||
startId int
|
||||
rows *sql.Rows
|
||||
}
|
||||
|
||||
func (iter *Iterator) IsValid() bool {
|
||||
return iter.session != nil && iter.rows != nil
|
||||
}
|
||||
|
||||
/*
|
||||
func (iter *Iterator) Next(bean interface{}) (bool, error) {
|
||||
if !iter.IsValid() {
|
||||
return errors.New("iterator is not valied.")
|
||||
}
|
||||
if iter.rows.Next() {
|
||||
iter.rows.Scan(...)
|
||||
}
|
||||
}*/
|
||||
|
||||
// close the iterator
|
||||
func (iter *Iterator) Close() {
|
||||
if iter.rows != nil {
|
||||
iter.rows.Close()
|
||||
}
|
||||
if iter.session != nil && iter.session.IsAutoClose {
|
||||
iter.session.Close()
|
||||
}
|
||||
}
|
|
@ -101,6 +101,10 @@ func TestPostgres2(t *testing.T) {
|
|||
testCreatedAndUpdated(engine, t)
|
||||
fmt.Println("-------------- testIndexAndUnique --------------")
|
||||
testIndexAndUnique(engine, t)
|
||||
fmt.Println("-------------- testMetaInfo --------------")
|
||||
testMetaInfo(engine, t)
|
||||
fmt.Println("-------------- testIterate --------------")
|
||||
testIterate(engine, t)
|
||||
}
|
||||
|
||||
func BenchmarkPostgresNoCache(t *testing.B) {
|
||||
|
|
118
session.go
118
session.go
|
@ -82,6 +82,11 @@ func (session *Session) Cols(columns ...string) *Session {
|
|||
return session
|
||||
}
|
||||
|
||||
func (session *Session) Omit(columns ...string) *Session {
|
||||
session.Statement.Omit(columns...)
|
||||
return session
|
||||
}
|
||||
|
||||
// Method NoAutoTime means do not automatically give created field and updated field
|
||||
// the current time on the current session temporarily
|
||||
func (session *Session) NoAutoTime() *Session {
|
||||
|
@ -663,6 +668,71 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
|
|||
return nil
|
||||
}
|
||||
|
||||
type IterFunc func(idx int, bean interface{}) error
|
||||
|
||||
func (session *Session) Iterate(bean interface{}, fun IterFunc) error {
|
||||
err := session.newDb()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer session.Statement.Init()
|
||||
if session.IsAutoClose {
|
||||
defer session.Close()
|
||||
}
|
||||
|
||||
var sql string
|
||||
var args []interface{}
|
||||
session.Statement.RefTable = session.Engine.AutoMap(bean)
|
||||
if session.Statement.RawSQL == "" {
|
||||
sql, args = session.Statement.genGetSql(bean)
|
||||
} else {
|
||||
sql = session.Statement.RawSQL
|
||||
args = session.Statement.RawParams
|
||||
}
|
||||
|
||||
for _, filter := range session.Engine.Filters {
|
||||
sql = filter.Do(sql, session)
|
||||
}
|
||||
|
||||
session.Engine.LogSQL(sql)
|
||||
session.Engine.LogSQL(args)
|
||||
|
||||
s, err := session.Db.Prepare(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
rows, err := s.Query(args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
fields, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t := reflect.Indirect(reflect.ValueOf(bean)).Type()
|
||||
b := reflect.New(t).Interface()
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
result, err := row2map(rows, fields)
|
||||
if err == nil {
|
||||
err = session.scanMapIntoStruct(b, result)
|
||||
}
|
||||
if err == nil {
|
||||
err = fun(i, b)
|
||||
i = i + 1
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// get retrieve one record from database
|
||||
func (session *Session) Get(bean interface{}) (bool, error) {
|
||||
err := session.newDb()
|
||||
|
@ -982,29 +1052,14 @@ func (session *Session) DropAll() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
|
||||
s, err := db.Prepare(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.Close()
|
||||
res, err := s.Query(params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Close()
|
||||
fields, err := res.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for res.Next() {
|
||||
func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err error) {
|
||||
result := make(map[string][]byte)
|
||||
var scanResultContainers []interface{}
|
||||
for i := 0; i < len(fields); i++ {
|
||||
var scanResultContainer interface{}
|
||||
scanResultContainers = append(scanResultContainers, &scanResultContainer)
|
||||
}
|
||||
if err := res.Scan(scanResultContainers...); err != nil {
|
||||
if err := rows.Scan(scanResultContainers...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for ii, key := range fields {
|
||||
|
@ -1050,11 +1105,40 @@ func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[st
|
|||
//session.Engine.LogError("Unsupported type")
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) {
|
||||
fields, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
result, err := row2map(rows, fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultsSlice = append(resultsSlice, result)
|
||||
}
|
||||
|
||||
return resultsSlice, nil
|
||||
}
|
||||
|
||||
func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
|
||||
s, err := db.Prepare(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer s.Close()
|
||||
rows, err := s.Query(params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return rows2maps(rows)
|
||||
}
|
||||
|
||||
func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
|
||||
for _, filter := range session.Engine.Filters {
|
||||
sql = filter.Do(sql, session)
|
||||
|
|
30
statement.go
30
statement.go
|
@ -22,6 +22,7 @@ type Statement struct {
|
|||
HavingStr string
|
||||
ColumnStr string
|
||||
columnMap map[string]bool
|
||||
OmitStr string
|
||||
ConditionStr string
|
||||
AltTableName string
|
||||
RawSQL string
|
||||
|
@ -47,6 +48,7 @@ func (statement *Statement) Init() {
|
|||
statement.GroupByStr = ""
|
||||
statement.HavingStr = ""
|
||||
statement.ColumnStr = ""
|
||||
statement.OmitStr = ""
|
||||
statement.columnMap = make(map[string]bool)
|
||||
statement.ConditionStr = ""
|
||||
statement.AltTableName = ""
|
||||
|
@ -82,9 +84,10 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
|
|||
for _, col := range table.Columns {
|
||||
fieldValue := col.ValueOf(bean)
|
||||
fieldType := reflect.TypeOf(fieldValue.Interface())
|
||||
val := fieldValue.Interface()
|
||||
var val interface{}
|
||||
switch fieldType.Kind() {
|
||||
case reflect.Bool:
|
||||
val = fieldValue.Interface()
|
||||
case reflect.String:
|
||||
if fieldValue.String() == "" {
|
||||
continue
|
||||
|
@ -92,19 +95,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
|
|||
// for MyString, should convert to string or panic
|
||||
if fieldType.String() != reflect.String.String() {
|
||||
val = fieldValue.String()
|
||||
} else {
|
||||
val = fieldValue.Interface()
|
||||
}
|
||||
case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
|
||||
if fieldValue.Int() == 0 {
|
||||
continue
|
||||
}
|
||||
val = fieldValue.Interface()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
if fieldValue.Float() == 0.0 {
|
||||
continue
|
||||
}
|
||||
val = fieldValue.Interface()
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
if fieldValue.Uint() == 0 {
|
||||
continue
|
||||
}
|
||||
val = fieldValue.Interface()
|
||||
case reflect.Struct:
|
||||
if fieldType == reflect.TypeOf(time.Now()) {
|
||||
t := fieldValue.Interface().(time.Time)
|
||||
|
@ -121,6 +129,8 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string,
|
|||
} else {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
val = fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
|
@ -219,6 +229,14 @@ func (statement *Statement) Cols(columns ...string) {
|
|||
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
|
||||
}
|
||||
|
||||
func (statement *Statement) Omit(columns ...string) {
|
||||
newColumns := col2NewCols(columns...)
|
||||
for _, nc := range newColumns {
|
||||
statement.columnMap[nc] = false
|
||||
}
|
||||
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
|
||||
}
|
||||
|
||||
func (statement *Statement) Limit(limit int, start ...int) {
|
||||
statement.LimitN = limit
|
||||
if len(start) > 0 {
|
||||
|
@ -251,10 +269,16 @@ func (statement *Statement) genColumnStr() string {
|
|||
table := statement.RefTable
|
||||
colNames := make([]string, 0)
|
||||
for _, col := range table.Columns {
|
||||
if col.MapType != ONLYTODB {
|
||||
colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name))
|
||||
if statement.OmitStr != "" {
|
||||
if _, ok := statement.columnMap[col.Name]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if col.MapType == ONLYTODB {
|
||||
continue
|
||||
}
|
||||
colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name))
|
||||
}
|
||||
return strings.Join(colNames, ", ")
|
||||
}
|
||||
|
||||
|
|
5
table.go
5
table.go
|
@ -313,6 +313,11 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc
|
|||
continue
|
||||
}
|
||||
}
|
||||
if session.Statement.OmitStr != "" {
|
||||
if _, ok := session.Statement.columnMap[col.Name]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime {
|
||||
args = append(args, time.Now())
|
||||
|
|
|
@ -43,5 +43,9 @@ func unTitle(src string) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
if len(src) == 1 {
|
||||
return strings.ToLower(string(src[0]))
|
||||
} else {
|
||||
return strings.ToLower(string(src[0])) + src[1:]
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue