Merge branch 'master' into lunny/update_join
This commit is contained in:
commit
77a1305ef6
|
@ -125,6 +125,11 @@ func TestWithTableName(t *testing.T) {
|
|||
total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 2, total)
|
||||
|
||||
// the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by
|
||||
total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 2, total)
|
||||
}
|
||||
|
||||
func TestCountWithSelectCols(t *testing.T) {
|
||||
|
|
|
@ -30,7 +30,7 @@ func TestQueryString(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(GetVar2)))
|
||||
|
||||
var data = GetVar2{
|
||||
data := GetVar2{
|
||||
Msg: "hi",
|
||||
Age: 28,
|
||||
Money: 1.5,
|
||||
|
@ -58,7 +58,7 @@ func TestQueryString2(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(GetVar3)))
|
||||
|
||||
var data = GetVar3{
|
||||
data := GetVar3{
|
||||
Msg: false,
|
||||
}
|
||||
_, err := testEngine.Insert(data)
|
||||
|
@ -95,7 +95,7 @@ func TestQueryInterface(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(GetVarInterface)))
|
||||
|
||||
var data = GetVarInterface{
|
||||
data := GetVarInterface{
|
||||
Msg: "hi",
|
||||
Age: 28,
|
||||
Money: 1.5,
|
||||
|
@ -128,7 +128,7 @@ func TestQueryNoParams(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(QueryNoParams)))
|
||||
|
||||
var q = QueryNoParams{
|
||||
q := QueryNoParams{
|
||||
Msg: "message",
|
||||
Age: 20,
|
||||
Money: 3000,
|
||||
|
@ -172,7 +172,7 @@ func TestQueryStringNoParam(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(GetVar4)))
|
||||
|
||||
var data = GetVar4{
|
||||
data := GetVar4{
|
||||
Msg: false,
|
||||
}
|
||||
_, err := testEngine.Insert(data)
|
||||
|
@ -209,7 +209,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(GetVar6)))
|
||||
|
||||
var data = GetVar6{
|
||||
data := GetVar6{
|
||||
Msg: false,
|
||||
}
|
||||
_, err := testEngine.Insert(data)
|
||||
|
@ -246,7 +246,7 @@ func TestQueryInterfaceNoParam(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(GetVar5)))
|
||||
|
||||
var data = GetVar5{
|
||||
data := GetVar5{
|
||||
Msg: false,
|
||||
}
|
||||
_, err := testEngine.Insert(data)
|
||||
|
@ -280,7 +280,7 @@ func TestQueryWithBuilder(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(QueryWithBuilder)))
|
||||
|
||||
var q = QueryWithBuilder{
|
||||
q := QueryWithBuilder{
|
||||
Msg: "message",
|
||||
Age: 20,
|
||||
Money: 3000,
|
||||
|
@ -329,14 +329,14 @@ func TestJoinWithSubQuery(t *testing.T) {
|
|||
|
||||
assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart)))
|
||||
|
||||
var depart = JoinWithSubQueryDepart{
|
||||
depart := JoinWithSubQueryDepart{
|
||||
Name: "depart1",
|
||||
}
|
||||
cnt, err := testEngine.Insert(&depart)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
||||
var q = JoinWithSubQuery1{
|
||||
q := JoinWithSubQuery1{
|
||||
Msg: "message",
|
||||
DepartId: depart.Id,
|
||||
Money: 3000,
|
||||
|
@ -401,7 +401,7 @@ func TestQueryBLOBInMySQL(t *testing.T) {
|
|||
}
|
||||
|
||||
const N = 10
|
||||
var data = []Avatar{}
|
||||
data := []Avatar{}
|
||||
for i := 0; i < N; i++ {
|
||||
// allocate a []byte that is as twice big as the last one
|
||||
// so that the underlying buffer will need to reallocate when querying
|
||||
|
@ -448,3 +448,54 @@ func TestQueryBLOBInMySQL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowsReset(t *testing.T) {
|
||||
assert.NoError(t, PrepareEngine())
|
||||
|
||||
type RowsReset1 struct {
|
||||
Id int64
|
||||
Name string
|
||||
}
|
||||
|
||||
type RowsReset2 struct {
|
||||
Id int64
|
||||
Name string
|
||||
}
|
||||
|
||||
assert.NoError(t, testEngine.Sync(new(RowsReset1), new(RowsReset2)))
|
||||
|
||||
data := []RowsReset1{
|
||||
{0, "1"},
|
||||
{0, "2"},
|
||||
{0, "3"},
|
||||
}
|
||||
_, err := testEngine.Insert(data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
data2 := []RowsReset2{
|
||||
{0, "4"},
|
||||
{0, "5"},
|
||||
{0, "6"},
|
||||
}
|
||||
_, err = testEngine.Insert(data2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sess := testEngine.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
rows, err := sess.Rows(new(RowsReset1))
|
||||
assert.NoError(t, err)
|
||||
for rows.Next() {
|
||||
var data1 RowsReset1
|
||||
assert.NoError(t, rows.Scan(&data1))
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
var rrs []RowsReset2
|
||||
assert.NoError(t, sess.Find(&rrs))
|
||||
|
||||
assert.Len(t, rrs, 3)
|
||||
assert.EqualValues(t, "4", rrs[0].Name)
|
||||
assert.EqualValues(t, "5", rrs[1].Name)
|
||||
assert.EqualValues(t, "6", rrs[2].Name)
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
|||
}
|
||||
|
||||
buf := builder.NewWriter()
|
||||
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
|
||||
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return buf.String(), buf.Args(), nil
|
||||
|
@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
|
|||
}
|
||||
|
||||
buf := builder.NewWriter()
|
||||
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil {
|
||||
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return buf.String(), buf.Args(), nil
|
||||
|
@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
|||
}
|
||||
|
||||
buf := builder.NewWriter()
|
||||
if err := statement.writeSelect(buf, columnStr, true); err != nil {
|
||||
if err := statement.writeSelect(buf, columnStr, true, true); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return buf.String(), buf.Args(), nil
|
||||
|
@ -153,12 +153,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
|||
selectSQL = "count(*)"
|
||||
}
|
||||
}
|
||||
var subQuerySelect string
|
||||
if statement.GroupByStr != "" {
|
||||
subQuerySelect = statement.GroupByStr
|
||||
} else {
|
||||
subQuerySelect = selectSQL
|
||||
}
|
||||
|
||||
buf := builder.NewWriter()
|
||||
if statement.GroupByStr != "" {
|
||||
|
@ -167,7 +161,14 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
|||
}
|
||||
}
|
||||
|
||||
if err := statement.writeSelect(buf, subQuerySelect, false); err != nil {
|
||||
var subQuerySelect string
|
||||
if statement.GroupByStr != "" {
|
||||
subQuerySelect = statement.GroupByStr
|
||||
} else {
|
||||
subQuerySelect = selectSQL
|
||||
}
|
||||
|
||||
if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
@ -364,7 +365,7 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s
|
|||
return err
|
||||
}
|
||||
|
||||
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error {
|
||||
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error {
|
||||
if err := statement.writeSelectColumns(buf, columnStr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -380,8 +381,10 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
|
|||
if err := statement.writeHaving(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeOrderBys(buf); err != nil {
|
||||
return err
|
||||
if needOrderBy {
|
||||
if err := statement.writeOrderBys(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dialect := statement.dialect
|
||||
|
@ -519,7 +522,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
|
|||
statement.cond = statement.cond.And(autoCond)
|
||||
|
||||
buf := builder.NewWriter()
|
||||
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
|
||||
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return buf.String(), buf.Args(), nil
|
||||
|
|
|
@ -427,7 +427,158 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
|
|||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error {
|
||||
func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
|
||||
switch t := m.(type) {
|
||||
case map[string]interface{}:
|
||||
conds := []builder.Cond{}
|
||||
for k, v := range t {
|
||||
conds = append(conds, builder.Eq{k: v})
|
||||
}
|
||||
return conds, nil
|
||||
case map[string]string:
|
||||
conds := []builder.Cond{}
|
||||
for k, v := range t {
|
||||
conds = append(conds, builder.Eq{k: v})
|
||||
}
|
||||
return conds, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported condition map type %v", t)
|
||||
}
|
||||
}
|
||||
|
||||
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
|
||||
if v.Type().Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
table := statement.RefTable
|
||||
if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
|
||||
return nil
|
||||
}
|
||||
|
||||
verValue, err := table.VersionColumn().ValueOfV(&v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if verValue == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||
for i, expr := range statement.IncrColumns {
|
||||
if i > 0 || hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.Arg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||
// for update action to like "column = column - ?"
|
||||
for i, expr := range statement.DecrColumns {
|
||||
if i > 0 || hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.Arg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
|
||||
// for update action to like "column = expression"
|
||||
for i, expr := range statement.ExprColumns {
|
||||
if i > 0 || hasPreviousSet {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
switch tp := expr.Arg.(type) {
|
||||
case string:
|
||||
if len(tp) == 0 {
|
||||
tp = "''"
|
||||
}
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
|
||||
return err
|
||||
}
|
||||
case *builder.Builder:
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.Arg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
previousLen := w.Len()
|
||||
for i, colName := range colNames {
|
||||
if i > 0 {
|
||||
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(w, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.Append(args...)
|
||||
|
||||
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -450,17 +601,16 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
|
|||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||
return err
|
||||
}
|
||||
for i, colName := range colNames {
|
||||
if i > 0 {
|
||||
if _, err := fmt.Fprint(updateWriter, ", "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := fmt.Fprint(updateWriter, colName); err != nil {
|
||||
return err
|
||||
}
|
||||
previousLen := updateWriter.Len()
|
||||
|
||||
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if no columns to be updated, return error
|
||||
if previousLen == updateWriter.Len() {
|
||||
return ErrNoColumnsTobeUpdated
|
||||
}
|
||||
updateWriter.Append(args...)
|
||||
|
||||
// write from
|
||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||
|
|
2
rows.go
2
rows.go
|
@ -144,6 +144,8 @@ func (rows *Rows) Close() error {
|
|||
defer rows.session.Close()
|
||||
}
|
||||
|
||||
defer rows.session.resetStatement()
|
||||
|
||||
if rows.rows != nil {
|
||||
return rows.rows.Close()
|
||||
}
|
||||
|
|
|
@ -5,17 +5,17 @@
|
|||
package xorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/internal/statements"
|
||||
"xorm.io/xorm/internal/utils"
|
||||
"xorm.io/xorm/schemas"
|
||||
)
|
||||
|
||||
// enumerated all errors
|
||||
var (
|
||||
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||
ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
|
||||
)
|
||||
|
||||
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
||||
|
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
v := utils.ReflectValue(bean)
|
||||
t := v.Type()
|
||||
|
||||
var colNames []string
|
||||
var args []interface{}
|
||||
|
||||
// handle before update processors
|
||||
for _, closure := range session.beforeClosures {
|
||||
closure(bean)
|
||||
|
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
}
|
||||
// --
|
||||
|
||||
var colNames []string
|
||||
var args []interface{}
|
||||
var err error
|
||||
isMap := t.Kind() == reflect.Map
|
||||
isStruct := t.Kind() == reflect.Struct
|
||||
|
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
}
|
||||
}
|
||||
|
||||
// for update action to like "column = column + ?"
|
||||
incColumns := session.statement.IncrColumns
|
||||
for _, expr := range incColumns {
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
|
||||
args = append(args, expr.Arg)
|
||||
}
|
||||
// for update action to like "column = column - ?"
|
||||
decColumns := session.statement.DecrColumns
|
||||
for _, expr := range decColumns {
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
|
||||
args = append(args, expr.Arg)
|
||||
}
|
||||
// for update action to like "column = expression"
|
||||
exprColumns := session.statement.ExprColumns
|
||||
for _, expr := range exprColumns {
|
||||
switch tp := expr.Arg.(type) {
|
||||
case string:
|
||||
if len(tp) == 0 {
|
||||
tp = "''"
|
||||
}
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
|
||||
case *builder.Builder:
|
||||
subQuery, subArgs, err := builder.ToSQL(tp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
subQuery = session.statement.ReplaceQuote(subQuery)
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
|
||||
args = append(args, subArgs...)
|
||||
default:
|
||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
|
||||
args = append(args, expr.Arg)
|
||||
}
|
||||
}
|
||||
|
||||
if err = session.statement.ProcessIDParam(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
verValue *reflect.Value
|
||||
)
|
||||
if doIncVer {
|
||||
verValue, err = table.VersionColumn().ValueOf(bean)
|
||||
verValue, err = table.VersionColumn().ValueOfV(&v)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if verValue != nil {
|
||||
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
||||
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
|
||||
}
|
||||
}
|
||||
|
||||
if len(colNames) == 0 {
|
||||
return 0, ErrNoColumnsTobeUpdated
|
||||
}
|
||||
|
||||
updateWriter := builder.NewWriter()
|
||||
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil {
|
||||
if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue