diff --git a/base_test.go b/base_test.go index 1f4a6ecc..a58e0304 100644 --- a/base_test.go +++ b/base_test.go @@ -1115,6 +1115,19 @@ func testMetaInfo(engine *Engine, t *testing.T) { } } +func testIterate(engine *Engine, t *testing.T) { + err := engine.Omit("is_man").Iterate(new(Userinfo), func(idx int, bean interface{}) error { + user := bean.(*Userinfo) + fmt.Println(idx, "--", user) + return nil + }) + + if err != nil { + t.Error(err) + panic(err) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -1195,6 +1208,8 @@ func testAll2(engine *Engine, t *testing.T) { //testInt32Id(engine, t) fmt.Println("-------------- testMetaInfo --------------") testMetaInfo(engine, t) + fmt.Println("-------------- testIterate --------------") + testIterate(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/engine.go b/engine.go index e9e2d683..a811f390 100644 --- a/engine.go +++ b/engine.go @@ -250,6 +250,12 @@ func (engine *Engine) Cols(columns ...string) *Session { return session.Cols(columns...) } +func (engine *Engine) Omit(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Omit(columns...) +} + /*func (engine *Engine) Trans(t string) *Session { session := engine.NewSession() session.IsAutoClose = true @@ -799,6 +805,12 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { return session.Find(beans, condiBeans...) } +func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { + session := engine.NewSession() + defer session.Close() + return session.Iterate(bean, fun) +} + func (engine *Engine) Count(bean interface{}) (int64, error) { session := engine.NewSession() defer session.Close() diff --git a/iterator.go b/iterator.go new file mode 100644 index 00000000..f679ae18 --- /dev/null +++ b/iterator.go @@ -0,0 +1,35 @@ +package xorm + +import ( + "database/sql" +) + +type Iterator struct { + session *Session + startId int + rows *sql.Rows +} + +func (iter *Iterator) IsValid() bool { + return iter.session != nil && iter.rows != nil +} + +/* +func (iter *Iterator) Next(bean interface{}) (bool, error) { + if !iter.IsValid() { + return errors.New("iterator is not valied.") + } + if iter.rows.Next() { + iter.rows.Scan(...) + } +}*/ + +// close the iterator +func (iter *Iterator) Close() { + if iter.rows != nil { + iter.rows.Close() + } + if iter.session != nil && iter.session.IsAutoClose { + iter.session.Close() + } +} diff --git a/postgres_test.go b/postgres_test.go index 1afcd7f5..6a5e43ff 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -101,6 +101,10 @@ func TestPostgres2(t *testing.T) { testCreatedAndUpdated(engine, t) fmt.Println("-------------- testIndexAndUnique --------------") testIndexAndUnique(engine, t) + fmt.Println("-------------- testMetaInfo --------------") + testMetaInfo(engine, t) + fmt.Println("-------------- testIterate --------------") + testIterate(engine, t) } func BenchmarkPostgresNoCache(t *testing.B) { diff --git a/session.go b/session.go index 45ea082b..0da193da 100644 --- a/session.go +++ b/session.go @@ -82,6 +82,11 @@ func (session *Session) Cols(columns ...string) *Session { return session } +func (session *Session) Omit(columns ...string) *Session { + session.Statement.Omit(columns...) + return session +} + // Method NoAutoTime means do not automatically give created field and updated field // the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { @@ -663,6 +668,71 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter return nil } +type IterFunc func(idx int, bean interface{}) error + +func (session *Session) Iterate(bean interface{}, fun IterFunc) error { + err := session.newDb() + if err != nil { + return err + } + + defer session.Statement.Init() + if session.IsAutoClose { + defer session.Close() + } + + var sql string + var args []interface{} + session.Statement.RefTable = session.Engine.AutoMap(bean) + if session.Statement.RawSQL == "" { + sql, args = session.Statement.genGetSql(bean) + } else { + sql = session.Statement.RawSQL + args = session.Statement.RawParams + } + + for _, filter := range session.Engine.Filters { + sql = filter.Do(sql, session) + } + + session.Engine.LogSQL(sql) + session.Engine.LogSQL(args) + + s, err := session.Db.Prepare(sql) + if err != nil { + return err + } + defer s.Close() + rows, err := s.Query(args...) + if err != nil { + return err + } + defer rows.Close() + + fields, err := rows.Columns() + if err != nil { + return err + } + t := reflect.Indirect(reflect.ValueOf(bean)).Type() + b := reflect.New(t).Interface() + i := 0 + for rows.Next() { + result, err := row2map(rows, fields) + if err == nil { + err = session.scanMapIntoStruct(b, result) + } + if err == nil { + err = fun(i, b) + i = i + 1 + } + if err != nil { + return err + } + } + + return nil +} + // get retrieve one record from database func (session *Session) Get(bean interface{}) (bool, error) { err := session.newDb() @@ -982,77 +1052,91 @@ func (session *Session) DropAll() error { return nil } +func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err error) { + result := make(map[string][]byte) + var scanResultContainers []interface{} + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := rows.Scan(scanResultContainers...); err != nil { + return nil, err + } + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + //fmt.Println("ignore ...", key, rawValue) + continue + } + aa := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + var str string + switch aa.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + result[key] = []byte(str) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + result[key] = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + result[key] = []byte(str) + case reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + result[key] = rawValue.Interface().([]byte) + default: + //session.Engine.LogError("Unsupported type") + } + case reflect.String: + str = vv.String() + result[key] = []byte(str) + //时间类型 + case reflect.Struct: + if aa.String() == "time.Time" { + str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") + result[key] = []byte(str) + } else { + //session.Engine.LogError("Unsupported struct type") + } + default: + //session.Engine.LogError("Unsupported type") + } + } + return result, nil +} + +func rows2maps(rows *sql.Rows) (resultsSlice []map[string][]byte, err error) { + fields, err := rows.Columns() + if err != nil { + return nil, err + } + for rows.Next() { + result, err := row2map(rows, fields) + if err != nil { + return nil, err + } + resultsSlice = append(resultsSlice, result) + } + + return resultsSlice, nil +} + func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { s, err := db.Prepare(sql) if err != nil { return nil, err } defer s.Close() - res, err := s.Query(params...) + rows, err := s.Query(params...) if err != nil { return nil, err } - defer res.Close() - fields, err := res.Columns() - if err != nil { - return nil, err - } - for res.Next() { - result := make(map[string][]byte) - var scanResultContainers []interface{} - for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers = append(scanResultContainers, &scanResultContainer) - } - if err := res.Scan(scanResultContainers...); err != nil { - return nil, err - } - for ii, key := range fields { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + defer rows.Close() - //if row is null then ignore - if rawValue.Interface() == nil { - //fmt.Println("ignore ...", key, rawValue) - continue - } - aa := reflect.TypeOf(rawValue.Interface()) - vv := reflect.ValueOf(rawValue.Interface()) - var str string - switch aa.Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - str = strconv.FormatInt(vv.Int(), 10) - result[key] = []byte(str) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - str = strconv.FormatUint(vv.Uint(), 10) - result[key] = []byte(str) - case reflect.Float32, reflect.Float64: - str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) - result[key] = []byte(str) - case reflect.Slice: - switch aa.Elem().Kind() { - case reflect.Uint8: - result[key] = rawValue.Interface().([]byte) - default: - //session.Engine.LogError("Unsupported type") - } - case reflect.String: - str = vv.String() - result[key] = []byte(str) - //时间类型 - case reflect.Struct: - if aa.String() == "time.Time" { - str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") - result[key] = []byte(str) - } else { - //session.Engine.LogError("Unsupported struct type") - } - default: - //session.Engine.LogError("Unsupported type") - } - } - resultsSlice = append(resultsSlice, result) - } - return resultsSlice, nil + return rows2maps(rows) } func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { diff --git a/statement.go b/statement.go index 1d7ef006..9129be3b 100644 --- a/statement.go +++ b/statement.go @@ -22,6 +22,7 @@ type Statement struct { HavingStr string ColumnStr string columnMap map[string]bool + OmitStr string ConditionStr string AltTableName string RawSQL string @@ -47,6 +48,7 @@ func (statement *Statement) Init() { statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnStr = "" + statement.OmitStr = "" statement.columnMap = make(map[string]bool) statement.ConditionStr = "" statement.AltTableName = "" @@ -82,9 +84,10 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string, for _, col := range table.Columns { fieldValue := col.ValueOf(bean) fieldType := reflect.TypeOf(fieldValue.Interface()) - val := fieldValue.Interface() + var val interface{} switch fieldType.Kind() { case reflect.Bool: + val = fieldValue.Interface() case reflect.String: if fieldValue.String() == "" { continue @@ -92,19 +95,24 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string, // for MyString, should convert to string or panic if fieldType.String() != reflect.String.String() { val = fieldValue.String() + } else { + val = fieldValue.Interface() } case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: if fieldValue.Int() == 0 { continue } + val = fieldValue.Interface() case reflect.Float32, reflect.Float64: if fieldValue.Float() == 0.0 { continue } + val = fieldValue.Interface() case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: if fieldValue.Uint() == 0 { continue } + val = fieldValue.Interface() case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) @@ -121,6 +129,8 @@ func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string, } else { continue } + } else { + val = fieldValue.Interface() } } case reflect.Array, reflect.Slice, reflect.Map: @@ -219,6 +229,14 @@ func (statement *Statement) Cols(columns ...string) { statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } +func (statement *Statement) Omit(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.columnMap[nc] = false + } + statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) +} + func (statement *Statement) Limit(limit int, start ...int) { statement.LimitN = limit if len(start) > 0 { @@ -251,9 +269,15 @@ func (statement *Statement) genColumnStr() string { table := statement.RefTable colNames := make([]string, 0) for _, col := range table.Columns { - if col.MapType != ONLYTODB { - colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name)) + if statement.OmitStr != "" { + if _, ok := statement.columnMap[col.Name]; ok { + continue + } } + if col.MapType == ONLYTODB { + continue + } + colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name)) } return strings.Join(colNames, ", ") } diff --git a/table.go b/table.go index b7303808..1ecdc0cb 100644 --- a/table.go +++ b/table.go @@ -313,6 +313,11 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc continue } } + if session.Statement.OmitStr != "" { + if _, ok := session.Statement.columnMap[col.Name]; ok { + continue + } + } if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime { args = append(args, time.Now()) diff --git a/xorm/lang.go b/xorm/lang.go index d7f02d20..17d09a98 100644 --- a/xorm/lang.go +++ b/xorm/lang.go @@ -43,5 +43,9 @@ func unTitle(src string) string { return "" } - return strings.ToLower(string(src[0])) + src[1:] + if len(src) == 1 { + return strings.ToLower(string(src[0])) + } else { + return strings.ToLower(string(src[0])) + src[1:] + } }