diff --git a/README.md b/README.md index f4bee6b6..4cb4ac13 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,11 @@ has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id) has, err := engine.SQL("select id from user").Get(&id) // SELECT id FROM user WHERE name = ? +var id int64 +var name string +has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) +// SELECT id, name FROM user LIMIT 1 + var valuesMap = make(map[string]string) has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap) // SELECT * FROM user WHERE id = ? @@ -234,7 +239,11 @@ err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean inter }) // SELECT * FROM user Limit 0, 100 // SELECT * FROM user Limit 101, 100 +``` +You can use rows which is similiar with `sql.Rows` + +```Go rows, err := engine.Rows(&User{Name:name}) // SELECT * FROM user defer rows.Close() @@ -244,6 +253,19 @@ for rows.Next() { } ``` +or + +```Go +rows, err := engine.Cols("name", "age").Rows(&User{Name:name}) +// SELECT * FROM user +defer rows.Close() +for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) +} +``` + * `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on. ```Go diff --git a/README_CN.md b/README_CN.md index 500bb1fb..c87aa079 100644 --- a/README_CN.md +++ b/README_CN.md @@ -158,6 +158,11 @@ has, err := engine.Table(&user).Where("name = ?", name).Cols("id").Get(&id) has, err := engine.SQL("select id from user").Get(&id) // SELECT id FROM user WHERE name = ? +var id int64 +var name string +has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) +// SELECT id, name FROM user LIMIT 1 + var valuesMap = make(map[string]string) has, err := engine.Table(&user).Where("id = ?", id).Get(&valuesMap) // SELECT * FROM user WHERE id = ? @@ -231,7 +236,11 @@ err := engine.BufferSize(100).Iterate(&User{Name:name}, func(idx int, bean inter }) // SELECT * FROM user Limit 0, 100 // SELECT * FROM user Limit 101, 100 +``` +Rows 的用法类似 `sql.Rows`。 + +```Go rows, err := engine.Rows(&User{Name:name}) // SELECT * FROM user defer rows.Close() @@ -241,6 +250,19 @@ for rows.Next() { } ``` +或者 + +```Go +rows, err := engine.Cols("name", "age").Rows(&User{Name:name}) +// SELECT * FROM user +defer rows.Close() +for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) +} +``` + * `Update` 更新数据,除非使用Cols,AllCols函数指明,默认只更新非空和非0的字段 ```Go diff --git a/doc.go b/doc.go index d0653232..a1565806 100644 --- a/doc.go +++ b/doc.go @@ -67,6 +67,11 @@ There are 8 major ORM methods and many helpful methods to use to operate databas has, err := engine.Table("user").Where("name = ?", name).Get(&id) // SELECT id FROM user WHERE name = ? LIMIT 1 + var id int64 + var name string + has, err := engine.Table(&user).Cols("id", "name").Get(&id, &name) + // SELECT id, name FROM user LIMIT 1 + 3. Query multiple records from database var sliceOfStructs []Struct @@ -97,6 +102,17 @@ another is Rows err = rows.Scan(bean) } +or + + rows, err := engine.Cols("name", "age").Rows(...) + // SELECT * FROM user + defer rows.Close() + for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) + } + 5. Update one or more records affected, err := engine.ID(...).Update(&user) diff --git a/engine.go b/engine.go index 3681ff3c..7d72dea3 100644 --- a/engine.go +++ b/engine.go @@ -1135,10 +1135,10 @@ func (engine *Engine) Delete(beans ...interface{}) (int64, error) { // Get retrieve one record from table, bean's non-empty fields // are conditions -func (engine *Engine) Get(bean interface{}) (bool, error) { +func (engine *Engine) Get(beans ...interface{}) (bool, error) { session := engine.NewSession() defer session.Close() - return session.Get(bean) + return session.Get(beans...) } // Exist returns true if the record exist otherwise return false diff --git a/integrations/rows_test.go b/integrations/rows_test.go index f68030a4..a5648675 100644 --- a/integrations/rows_test.go +++ b/integrations/rows_test.go @@ -160,5 +160,49 @@ func TestRowsSpecTableName(t *testing.T) { assert.NoError(t, err) cnt++ } + assert.NoError(t, rows.Err()) assert.EqualValues(t, 1, cnt) } + +func TestRowsScanVars(t *testing.T) { + type RowsScanVars struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assert.NoError(t, testEngine.Sync2(new(RowsScanVars))) + + cnt, err := testEngine.Insert(&RowsScanVars{ + Name: "xlw", + Age: 42, + }, &RowsScanVars{ + Name: "xlw2", + Age: 24, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + rows, err := testEngine.Cols("name", "age").Rows(new(RowsScanVars)) + assert.NoError(t, err) + defer rows.Close() + + cnt = 0 + for rows.Next() { + var name string + var age int + err = rows.Scan(&name, &age) + assert.NoError(t, err) + if cnt == 0 { + assert.EqualValues(t, "xlw", name) + assert.EqualValues(t, 42, age) + } else if cnt == 1 { + assert.EqualValues(t, "xlw2", name) + assert.EqualValues(t, 24, age) + } + cnt++ + } + assert.NoError(t, rows.Err()) + assert.EqualValues(t, 2, cnt) +} diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index 4fc30adb..6f4c1dbe 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -890,3 +890,28 @@ func TestGetTime(t *testing.T) { assert.True(t, has) assert.EqualValues(t, gts.CreateTime.Format(time.RFC3339), gn.Format(time.RFC3339)) } + +func TestGetVars(t *testing.T) { + type GetVars struct { + Id int64 + Name string + Age int + } + + assert.NoError(t, PrepareEngine()) + assertSync(t, new(GetVars)) + + _, err := testEngine.Insert(&GetVars{ + Name: "xlw", + Age: 42, + }) + assert.NoError(t, err) + + var name string + var age int + has, err := testEngine.Table(new(GetVars)).Cols("name", "age").Get(&name, &age) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, "xlw", name) + assert.EqualValues(t, 42, age) +} diff --git a/interface.go b/interface.go index 5d68f536..42dc9a0a 100644 --- a/interface.go +++ b/interface.go @@ -37,7 +37,7 @@ type Interface interface { Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error FindAndCount(interface{}, ...interface{}) (int64, error) - Get(interface{}) (bool, error) + Get(...interface{}) (bool, error) GroupBy(keys string) *Session ID(interface{}) *Session In(string, ...interface{}) *Session diff --git a/rows.go b/rows.go index 8e7cc075..76fc1e90 100644 --- a/rows.go +++ b/rows.go @@ -11,7 +11,6 @@ import ( "xorm.io/builder" "xorm.io/xorm/core" - "xorm.io/xorm/internal/utils" ) // Rows rows wrapper a rows to @@ -93,17 +92,26 @@ func (rows *Rows) Err() error { } // Scan row record to bean properties -func (rows *Rows) Scan(bean interface{}) error { +func (rows *Rows) Scan(beans ...interface{}) error { if rows.Err() != nil { return rows.Err() } - if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { - return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) + var bean = beans[0] + var tp = reflect.TypeOf(bean) + if tp.Kind() == reflect.Ptr { + tp = tp.Elem() } + var beanKind = tp.Kind() - if err := rows.session.statement.SetRefBean(bean); err != nil { - return err + if len(beans) == 1 { + if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { + return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) + } + + if err := rows.session.statement.SetRefBean(bean); err != nil { + return err + } } fields, err := rows.rows.Columns() @@ -115,14 +123,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } - scanResults, err := rows.session.row2Slice(rows.rows, fields, types, bean) - if err != nil { - return err - } - - dataStruct := utils.ReflectValue(bean) - _, err = rows.session.slice2Bean(scanResults, fields, bean, &dataStruct, rows.session.statement.RefTable) - if err != nil { + if err := rows.session.scan(rows.rows, rows.session.statement.RefTable, beanKind, beans, types, fields); err != nil { return err } diff --git a/session_get.go b/session_get.go index 22b116a9..a82cae92 100644 --- a/session_get.go +++ b/session_get.go @@ -28,11 +28,11 @@ var ( // Get retrieve one record from database, bean's non-empty fields // will be as conditions -func (session *Session) Get(bean interface{}) (bool, error) { +func (session *Session) Get(beans ...interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } - return session.get(bean) + return session.get(beans...) } func isPtrOfTime(v interface{}) bool { @@ -48,14 +48,17 @@ func isPtrOfTime(v interface{}) bool { return el.Type().ConvertibleTo(schemas.TimeType) } -func (session *Session) get(bean interface{}) (bool, error) { +func (session *Session) get(beans ...interface{}) (bool, error) { defer session.resetStatement() if session.statement.LastError != nil { return false, session.statement.LastError } + if len(beans) == 0 { + return false, errors.New("needs at least one parameter for get") + } - beanValue := reflect.ValueOf(bean) + beanValue := reflect.ValueOf(beans[0]) if beanValue.Kind() != reflect.Ptr { return false, errors.New("needs a pointer to a value") } else if beanValue.Elem().Kind() == reflect.Ptr { @@ -64,8 +67,9 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, ErrObjectIsNil } - if beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(bean) { - if err := session.statement.SetRefBean(bean); err != nil { + var isStruct = beanValue.Elem().Kind() == reflect.Struct && !isPtrOfTime(beans[0]) + if isStruct { + if err := session.statement.SetRefBean(beans[0]); err != nil { return false, err } } @@ -79,7 +83,7 @@ func (session *Session) get(bean interface{}) (bool, error) { return false, ErrTableNotFound } session.statement.Limit(1) - sqlStr, args, err = session.statement.GenGetSQL(bean) + sqlStr, args, err = session.statement.GenGetSQL(beans[0]) if err != nil { return false, err } @@ -90,10 +94,10 @@ func (session *Session) get(bean interface{}) (bool, error) { table := session.statement.RefTable - if session.statement.ColumnMap.IsEmpty() && session.canCache() && beanValue.Elem().Kind() == reflect.Struct { + if session.statement.ColumnMap.IsEmpty() && session.canCache() && isStruct { if cacher := session.engine.GetCacher(session.statement.TableName()); cacher != nil && !session.statement.GetUnscoped() { - has, err := session.cacheGet(bean, sqlStr, args...) + has, err := session.cacheGet(beans[0], sqlStr, args...) if err != ErrCacheFailed { return has, err } @@ -101,12 +105,12 @@ func (session *Session) get(bean interface{}) (bool, error) { } context := session.statement.Context - if context != nil { + if context != nil && isStruct { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { session.engine.logger.Debugf("hit context cache: %s", sqlStr) - structValue := reflect.Indirect(reflect.ValueOf(bean)) + structValue := reflect.Indirect(reflect.ValueOf(beans[0])) structValue.Set(reflect.Indirect(reflect.ValueOf(res))) session.lastSQL = "" session.lastSQLArgs = nil @@ -114,13 +118,13 @@ func (session *Session) get(bean interface{}) (bool, error) { } } - has, err := session.nocacheGet(beanValue.Elem().Kind(), table, bean, sqlStr, args...) + has, err := session.nocacheGet(beanValue.Elem().Kind(), table, beans, sqlStr, args...) if err != nil || !has { return has, err } - if context != nil { - context.Put(fmt.Sprintf("%v-%v", sqlStr, args), bean) + if context != nil && isStruct { + context.Put(fmt.Sprintf("%v-%v", sqlStr, args), beans[0]) } return true, nil @@ -148,7 +152,7 @@ func isScannableStruct(bean interface{}, typeLen int) bool { return true } -func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, bean interface{}, sqlStr string, args ...interface{}) (bool, error) { +func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, beans []interface{}, sqlStr string, args ...interface{}) (bool, error) { rows, err := session.queryRows(sqlStr, args...) if err != nil { return false, err @@ -168,27 +172,39 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, if err != nil { return true, err } - switch beanKind { - case reflect.Struct: - if !isScannableStruct(bean, len(types)) { - break - } - return session.getStruct(rows, types, fields, table, bean) - case reflect.Slice: - return session.getSlice(rows, types, fields, bean) - case reflect.Map: - return session.getMap(rows, types, fields, bean) - } - return session.getVars(rows, types, fields, bean) + return true, session.scan(rows, table, beanKind, beans, types, fields) } -func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { +func (session *Session) scan(rows *core.Rows, table *schemas.Table, firstBeanKind reflect.Kind, beans []interface{}, types []*sql.ColumnType, fields []string) error { + if len(beans) == 1 { + bean := beans[0] + switch firstBeanKind { + case reflect.Struct: + if !isScannableStruct(bean, len(types)) { + break + } + return session.getStruct(rows, types, fields, table, bean) + case reflect.Slice: + return session.getSlice(rows, types, fields, bean) + case reflect.Map: + return session.getMap(rows, types, fields, bean) + } + } + + if len(beans) != len(types) { + return fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) + } + + return session.engine.scan(rows, fields, types, beans...) +} + +func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) error { switch t := bean.(type) { case *[]string: res, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { - return true, err + return err } var needAppend = len(*t) == 0 // both support slice is empty or has been initlized @@ -199,17 +215,17 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field (*t)[i] = r.(*sql.NullString).String } } - return true, nil + return nil case *[]interface{}: scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { - return true, err + return err } var needAppend = len(*t) == 0 for ii := range fields { s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) if err != nil { - return true, err + return err } if needAppend { *t = append(*t, s) @@ -217,54 +233,45 @@ func (session *Session) getSlice(rows *core.Rows, types []*sql.ColumnType, field (*t)[ii] = s } } - return true, nil + return nil default: - return true, fmt.Errorf("unspoorted slice type: %t", t) + return fmt.Errorf("unspoorted slice type: %t", t) } } -func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) (bool, error) { +func (session *Session) getMap(rows *core.Rows, types []*sql.ColumnType, fields []string, bean interface{}) error { switch t := bean.(type) { case *map[string]string: scanResults, err := session.engine.scanStringInterface(rows, fields, types) if err != nil { - return true, err + return err } for ii, key := range fields { (*t)[key] = scanResults[ii].(*sql.NullString).String } - return true, nil + return nil case *map[string]interface{}: scanResults, err := session.engine.scanInterfaces(rows, fields, types) if err != nil { - return true, err + return err } for ii, key := range fields { s, err := convert.Interface2Interface(session.engine.DatabaseTZ, scanResults[ii]) if err != nil { - return true, err + return err } (*t)[key] = s } - return true, nil + return nil default: - return true, fmt.Errorf("unspoorted map type: %t", t) + return fmt.Errorf("unspoorted map type: %t", t) } } -func (session *Session) getVars(rows *core.Rows, types []*sql.ColumnType, fields []string, beans ...interface{}) (bool, error) { - if len(beans) != len(types) { - return false, fmt.Errorf("expected columns %d, but only %d variables", len(types), len(beans)) - } - - err := session.engine.scan(rows, fields, types, beans...) - return true, err -} - -func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) (bool, error) { +func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fields []string, table *schemas.Table, bean interface{}) error { scanResults, err := session.row2Slice(rows, fields, types, bean) if err != nil { - return false, err + return err } // close it before convert data rows.Close() @@ -272,10 +279,10 @@ func (session *Session) getStruct(rows *core.Rows, types []*sql.ColumnType, fiel dataStruct := utils.ReflectValue(bean) _, err = session.slice2Bean(scanResults, fields, bean, &dataStruct, table) if err != nil { - return true, err + return err } - return true, session.executeProcessors() + return session.executeProcessors() } func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { @@ -354,7 +361,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf cacheBean := cacher.GetBean(tableName, sid) if cacheBean == nil { cacheBean = bean - has, err = session.nocacheGet(reflect.Struct, table, cacheBean, sqlStr, args...) + has, err = session.nocacheGet(reflect.Struct, table, []interface{}{cacheBean}, sqlStr, args...) if err != nil || !has { return has, err }