Merge branch 'master' into lunny/update_join

This commit is contained in:
Lunny Xiao 2023-07-25 18:52:23 +08:00
commit 77a1305ef6
6 changed files with 253 additions and 83 deletions

View File

@ -125,6 +125,11 @@ func TestWithTableName(t *testing.T) {
total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{}) total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, total) assert.EqualValues(t, 2, total)
// the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by
total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{})
assert.NoError(t, err)
assert.EqualValues(t, 2, total)
} }
func TestCountWithSelectCols(t *testing.T) { func TestCountWithSelectCols(t *testing.T) {

View File

@ -30,7 +30,7 @@ func TestQueryString(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar2))) assert.NoError(t, testEngine.Sync(new(GetVar2)))
var data = GetVar2{ data := GetVar2{
Msg: "hi", Msg: "hi",
Age: 28, Age: 28,
Money: 1.5, Money: 1.5,
@ -58,7 +58,7 @@ func TestQueryString2(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar3))) assert.NoError(t, testEngine.Sync(new(GetVar3)))
var data = GetVar3{ data := GetVar3{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -95,7 +95,7 @@ func TestQueryInterface(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVarInterface))) assert.NoError(t, testEngine.Sync(new(GetVarInterface)))
var data = GetVarInterface{ data := GetVarInterface{
Msg: "hi", Msg: "hi",
Age: 28, Age: 28,
Money: 1.5, Money: 1.5,
@ -128,7 +128,7 @@ func TestQueryNoParams(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(QueryNoParams))) assert.NoError(t, testEngine.Sync(new(QueryNoParams)))
var q = QueryNoParams{ q := QueryNoParams{
Msg: "message", Msg: "message",
Age: 20, Age: 20,
Money: 3000, Money: 3000,
@ -172,7 +172,7 @@ func TestQueryStringNoParam(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar4))) assert.NoError(t, testEngine.Sync(new(GetVar4)))
var data = GetVar4{ data := GetVar4{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -209,7 +209,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar6))) assert.NoError(t, testEngine.Sync(new(GetVar6)))
var data = GetVar6{ data := GetVar6{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -246,7 +246,7 @@ func TestQueryInterfaceNoParam(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar5))) assert.NoError(t, testEngine.Sync(new(GetVar5)))
var data = GetVar5{ data := GetVar5{
Msg: false, Msg: false,
} }
_, err := testEngine.Insert(data) _, err := testEngine.Insert(data)
@ -280,7 +280,7 @@ func TestQueryWithBuilder(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(QueryWithBuilder))) assert.NoError(t, testEngine.Sync(new(QueryWithBuilder)))
var q = QueryWithBuilder{ q := QueryWithBuilder{
Msg: "message", Msg: "message",
Age: 20, Age: 20,
Money: 3000, Money: 3000,
@ -329,14 +329,14 @@ func TestJoinWithSubQuery(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart))) assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart)))
var depart = JoinWithSubQueryDepart{ depart := JoinWithSubQueryDepart{
Name: "depart1", Name: "depart1",
} }
cnt, err := testEngine.Insert(&depart) cnt, err := testEngine.Insert(&depart)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var q = JoinWithSubQuery1{ q := JoinWithSubQuery1{
Msg: "message", Msg: "message",
DepartId: depart.Id, DepartId: depart.Id,
Money: 3000, Money: 3000,
@ -401,7 +401,7 @@ func TestQueryBLOBInMySQL(t *testing.T) {
} }
const N = 10 const N = 10
var data = []Avatar{} data := []Avatar{}
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
// allocate a []byte that is as twice big as the last one // allocate a []byte that is as twice big as the last one
// so that the underlying buffer will need to reallocate when querying // so that the underlying buffer will need to reallocate when querying
@ -448,3 +448,54 @@ func TestQueryBLOBInMySQL(t *testing.T) {
} }
} }
} }
func TestRowsReset(t *testing.T) {
assert.NoError(t, PrepareEngine())
type RowsReset1 struct {
Id int64
Name string
}
type RowsReset2 struct {
Id int64
Name string
}
assert.NoError(t, testEngine.Sync(new(RowsReset1), new(RowsReset2)))
data := []RowsReset1{
{0, "1"},
{0, "2"},
{0, "3"},
}
_, err := testEngine.Insert(data)
assert.NoError(t, err)
data2 := []RowsReset2{
{0, "4"},
{0, "5"},
{0, "6"},
}
_, err = testEngine.Insert(data2)
assert.NoError(t, err)
sess := testEngine.NewSession()
defer sess.Close()
rows, err := sess.Rows(new(RowsReset1))
assert.NoError(t, err)
for rows.Next() {
var data1 RowsReset1
assert.NoError(t, rows.Scan(&data1))
}
rows.Close()
var rrs []RowsReset2
assert.NoError(t, sess.Find(&rrs))
assert.Len(t, rrs, 3)
assert.EqualValues(t, "4", rrs[0].Name)
assert.EqualValues(t, "5", rrs[1].Name)
assert.EqualValues(t, "6", rrs[2].Name)
}

View File

@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil { if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, columnStr, true); err != nil { if err := statement.writeSelect(buf, columnStr, true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil
@ -153,12 +153,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
selectSQL = "count(*)" selectSQL = "count(*)"
} }
} }
var subQuerySelect string
if statement.GroupByStr != "" {
subQuerySelect = statement.GroupByStr
} else {
subQuerySelect = selectSQL
}
buf := builder.NewWriter() buf := builder.NewWriter()
if statement.GroupByStr != "" { if statement.GroupByStr != "" {
@ -167,7 +161,14 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
} }
} }
if err := statement.writeSelect(buf, subQuerySelect, false); err != nil { var subQuerySelect string
if statement.GroupByStr != "" {
subQuerySelect = statement.GroupByStr
} else {
subQuerySelect = selectSQL
}
if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil {
return "", nil, err return "", nil, err
} }
@ -364,7 +365,7 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s
return err return err
} }
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error { func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error {
if err := statement.writeSelectColumns(buf, columnStr); err != nil { if err := statement.writeSelectColumns(buf, columnStr); err != nil {
return err return err
} }
@ -380,8 +381,10 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
if err := statement.writeHaving(buf); err != nil { if err := statement.writeHaving(buf); err != nil {
return err return err
} }
if err := statement.writeOrderBys(buf); err != nil { if needOrderBy {
return err if err := statement.writeOrderBys(buf); err != nil {
return err
}
} }
dialect := statement.dialect dialect := statement.dialect
@ -519,7 +522,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
buf := builder.NewWriter() buf := builder.NewWriter()
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil { if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
return "", nil, err return "", nil, err
} }
return buf.String(), buf.Args(), nil return buf.String(), buf.Args(), nil

View File

@ -427,7 +427,158 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
} }
} }
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error { func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
switch t := m.(type) {
case map[string]interface{}:
conds := []builder.Cond{}
for k, v := range t {
conds = append(conds, builder.Eq{k: v})
}
return conds, nil
case map[string]string:
conds := []builder.Cond{}
for k, v := range t {
conds = append(conds, builder.Eq{k: v})
}
return conds, nil
default:
return nil, fmt.Errorf("unsupported condition map type %v", t)
}
}
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
if v.Type().Kind() != reflect.Struct {
return nil
}
table := statement.RefTable
if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
return nil
}
verValue, err := table.VersionColumn().ValueOfV(&v)
if err != nil {
return err
}
if verValue == nil {
return nil
}
if hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
return err
}
return nil
}
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
for i, expr := range statement.IncrColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
return nil
}
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
// for update action to like "column = column - ?"
for i, expr := range statement.DecrColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
return nil
}
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
// for update action to like "column = expression"
for i, expr := range statement.ExprColumns {
if i > 0 || hasPreviousSet {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
switch tp := expr.Arg.(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
return err
}
case *builder.Builder:
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
return err
}
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
return err
}
if _, err := fmt.Fprint(w, ")"); err != nil {
return err
}
default:
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
return err
}
w.Append(expr.Arg)
}
}
return nil
}
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
previousLen := w.Len()
for i, colName := range colNames {
if i > 0 {
if _, err := fmt.Fprint(w, ", "); err != nil {
return err
}
}
if _, err := fmt.Fprint(w, colName); err != nil {
return err
}
}
w.Append(args...)
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
return err
}
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
return err
}
return nil
}
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil { if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
return err return err
} }
@ -450,17 +601,16 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil { if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
return err return err
} }
for i, colName := range colNames { previousLen := updateWriter.Len()
if i > 0 {
if _, err := fmt.Fprint(updateWriter, ", "); err != nil { if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
return err return err
} }
}
if _, err := fmt.Fprint(updateWriter, colName); err != nil { // if no columns to be updated, return error
return err if previousLen == updateWriter.Len() {
} return ErrNoColumnsTobeUpdated
} }
updateWriter.Append(args...)
// write from // write from
if err := statement.writeUpdateFrom(updateWriter); err != nil { if err := statement.writeUpdateFrom(updateWriter); err != nil {

View File

@ -144,6 +144,8 @@ func (rows *Rows) Close() error {
defer rows.session.Close() defer rows.session.Close()
} }
defer rows.session.resetStatement()
if rows.rows != nil { if rows.rows != nil {
return rows.rows.Close() return rows.rows.Close()
} }

View File

@ -5,17 +5,17 @@
package xorm package xorm
import ( import (
"errors"
"reflect" "reflect"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/internal/statements"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// enumerated all errors // enumerated all errors
var ( var (
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
) )
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
v := utils.ReflectValue(bean) v := utils.ReflectValue(bean)
t := v.Type() t := v.Type()
var colNames []string
var args []interface{}
// handle before update processors // handle before update processors
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
closure(bean) closure(bean)
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
// -- // --
var colNames []string
var args []interface{}
var err error var err error
isMap := t.Kind() == reflect.Map isMap := t.Kind() == reflect.Map
isStruct := t.Kind() == reflect.Struct isStruct := t.Kind() == reflect.Struct
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
// for update action to like "column = column + ?"
incColumns := session.statement.IncrColumns
for _, expr := range incColumns {
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
args = append(args, expr.Arg)
}
// for update action to like "column = column - ?"
decColumns := session.statement.DecrColumns
for _, expr := range decColumns {
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
args = append(args, expr.Arg)
}
// for update action to like "column = expression"
exprColumns := session.statement.ExprColumns
for _, expr := range exprColumns {
switch tp := expr.Arg.(type) {
case string:
if len(tp) == 0 {
tp = "''"
}
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
case *builder.Builder:
subQuery, subArgs, err := builder.ToSQL(tp)
if err != nil {
return 0, err
}
subQuery = session.statement.ReplaceQuote(subQuery)
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
args = append(args, subArgs...)
default:
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
args = append(args, expr.Arg)
}
}
if err = session.statement.ProcessIDParam(); err != nil { if err = session.statement.ProcessIDParam(); err != nil {
return 0, err return 0, err
} }
@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
verValue *reflect.Value verValue *reflect.Value
) )
if doIncVer { if doIncVer {
verValue, err = table.VersionColumn().ValueOf(bean) verValue, err = table.VersionColumn().ValueOfV(&v)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if verValue != nil { if verValue != nil {
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()}) cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
} }
} }
if len(colNames) == 0 {
return 0, ErrNoColumnsTobeUpdated
}
updateWriter := builder.NewWriter() updateWriter := builder.NewWriter()
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil { if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
return 0, err return 0, err
} }