Merge branch 'master' into lunny/update_join
This commit is contained in:
commit
77a1305ef6
|
@ -125,6 +125,11 @@ func TestWithTableName(t *testing.T) {
|
||||||
total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
|
total, err = testEngine.OrderBy("count(`id`) desc").Count(CountWithTableName{})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 2, total)
|
assert.EqualValues(t, 2, total)
|
||||||
|
|
||||||
|
// the orderby will be ignored by count because some databases will return errors if the orderby columns not in group by
|
||||||
|
total, err = testEngine.OrderBy("`name`").Count(CountWithTableName{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 2, total)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCountWithSelectCols(t *testing.T) {
|
func TestCountWithSelectCols(t *testing.T) {
|
||||||
|
|
|
@ -30,7 +30,7 @@ func TestQueryString(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(GetVar2)))
|
assert.NoError(t, testEngine.Sync(new(GetVar2)))
|
||||||
|
|
||||||
var data = GetVar2{
|
data := GetVar2{
|
||||||
Msg: "hi",
|
Msg: "hi",
|
||||||
Age: 28,
|
Age: 28,
|
||||||
Money: 1.5,
|
Money: 1.5,
|
||||||
|
@ -58,7 +58,7 @@ func TestQueryString2(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(GetVar3)))
|
assert.NoError(t, testEngine.Sync(new(GetVar3)))
|
||||||
|
|
||||||
var data = GetVar3{
|
data := GetVar3{
|
||||||
Msg: false,
|
Msg: false,
|
||||||
}
|
}
|
||||||
_, err := testEngine.Insert(data)
|
_, err := testEngine.Insert(data)
|
||||||
|
@ -95,7 +95,7 @@ func TestQueryInterface(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(GetVarInterface)))
|
assert.NoError(t, testEngine.Sync(new(GetVarInterface)))
|
||||||
|
|
||||||
var data = GetVarInterface{
|
data := GetVarInterface{
|
||||||
Msg: "hi",
|
Msg: "hi",
|
||||||
Age: 28,
|
Age: 28,
|
||||||
Money: 1.5,
|
Money: 1.5,
|
||||||
|
@ -128,7 +128,7 @@ func TestQueryNoParams(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(QueryNoParams)))
|
assert.NoError(t, testEngine.Sync(new(QueryNoParams)))
|
||||||
|
|
||||||
var q = QueryNoParams{
|
q := QueryNoParams{
|
||||||
Msg: "message",
|
Msg: "message",
|
||||||
Age: 20,
|
Age: 20,
|
||||||
Money: 3000,
|
Money: 3000,
|
||||||
|
@ -172,7 +172,7 @@ func TestQueryStringNoParam(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(GetVar4)))
|
assert.NoError(t, testEngine.Sync(new(GetVar4)))
|
||||||
|
|
||||||
var data = GetVar4{
|
data := GetVar4{
|
||||||
Msg: false,
|
Msg: false,
|
||||||
}
|
}
|
||||||
_, err := testEngine.Insert(data)
|
_, err := testEngine.Insert(data)
|
||||||
|
@ -209,7 +209,7 @@ func TestQuerySliceStringNoParam(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(GetVar6)))
|
assert.NoError(t, testEngine.Sync(new(GetVar6)))
|
||||||
|
|
||||||
var data = GetVar6{
|
data := GetVar6{
|
||||||
Msg: false,
|
Msg: false,
|
||||||
}
|
}
|
||||||
_, err := testEngine.Insert(data)
|
_, err := testEngine.Insert(data)
|
||||||
|
@ -246,7 +246,7 @@ func TestQueryInterfaceNoParam(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(GetVar5)))
|
assert.NoError(t, testEngine.Sync(new(GetVar5)))
|
||||||
|
|
||||||
var data = GetVar5{
|
data := GetVar5{
|
||||||
Msg: false,
|
Msg: false,
|
||||||
}
|
}
|
||||||
_, err := testEngine.Insert(data)
|
_, err := testEngine.Insert(data)
|
||||||
|
@ -280,7 +280,7 @@ func TestQueryWithBuilder(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(QueryWithBuilder)))
|
assert.NoError(t, testEngine.Sync(new(QueryWithBuilder)))
|
||||||
|
|
||||||
var q = QueryWithBuilder{
|
q := QueryWithBuilder{
|
||||||
Msg: "message",
|
Msg: "message",
|
||||||
Age: 20,
|
Age: 20,
|
||||||
Money: 3000,
|
Money: 3000,
|
||||||
|
@ -329,14 +329,14 @@ func TestJoinWithSubQuery(t *testing.T) {
|
||||||
|
|
||||||
assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart)))
|
assert.NoError(t, testEngine.Sync(new(JoinWithSubQuery1), new(JoinWithSubQueryDepart)))
|
||||||
|
|
||||||
var depart = JoinWithSubQueryDepart{
|
depart := JoinWithSubQueryDepart{
|
||||||
Name: "depart1",
|
Name: "depart1",
|
||||||
}
|
}
|
||||||
cnt, err := testEngine.Insert(&depart)
|
cnt, err := testEngine.Insert(&depart)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, cnt)
|
assert.EqualValues(t, 1, cnt)
|
||||||
|
|
||||||
var q = JoinWithSubQuery1{
|
q := JoinWithSubQuery1{
|
||||||
Msg: "message",
|
Msg: "message",
|
||||||
DepartId: depart.Id,
|
DepartId: depart.Id,
|
||||||
Money: 3000,
|
Money: 3000,
|
||||||
|
@ -401,7 +401,7 @@ func TestQueryBLOBInMySQL(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const N = 10
|
const N = 10
|
||||||
var data = []Avatar{}
|
data := []Avatar{}
|
||||||
for i := 0; i < N; i++ {
|
for i := 0; i < N; i++ {
|
||||||
// allocate a []byte that is as twice big as the last one
|
// allocate a []byte that is as twice big as the last one
|
||||||
// so that the underlying buffer will need to reallocate when querying
|
// so that the underlying buffer will need to reallocate when querying
|
||||||
|
@ -448,3 +448,54 @@ func TestQueryBLOBInMySQL(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRowsReset(t *testing.T) {
|
||||||
|
assert.NoError(t, PrepareEngine())
|
||||||
|
|
||||||
|
type RowsReset1 struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RowsReset2 struct {
|
||||||
|
Id int64
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, testEngine.Sync(new(RowsReset1), new(RowsReset2)))
|
||||||
|
|
||||||
|
data := []RowsReset1{
|
||||||
|
{0, "1"},
|
||||||
|
{0, "2"},
|
||||||
|
{0, "3"},
|
||||||
|
}
|
||||||
|
_, err := testEngine.Insert(data)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
data2 := []RowsReset2{
|
||||||
|
{0, "4"},
|
||||||
|
{0, "5"},
|
||||||
|
{0, "6"},
|
||||||
|
}
|
||||||
|
_, err = testEngine.Insert(data2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
sess := testEngine.NewSession()
|
||||||
|
defer sess.Close()
|
||||||
|
|
||||||
|
rows, err := sess.Rows(new(RowsReset1))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
for rows.Next() {
|
||||||
|
var data1 RowsReset1
|
||||||
|
assert.NoError(t, rows.Scan(&data1))
|
||||||
|
}
|
||||||
|
rows.Close()
|
||||||
|
|
||||||
|
var rrs []RowsReset2
|
||||||
|
assert.NoError(t, sess.Find(&rrs))
|
||||||
|
|
||||||
|
assert.Len(t, rrs, 3)
|
||||||
|
assert.EqualValues(t, "4", rrs[0].Name)
|
||||||
|
assert.EqualValues(t, "5", rrs[1].Name)
|
||||||
|
assert.EqualValues(t, "6", rrs[2].Name)
|
||||||
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := builder.NewWriter()
|
buf := builder.NewWriter()
|
||||||
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
|
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
return buf.String(), buf.Args(), nil
|
return buf.String(), buf.Args(), nil
|
||||||
|
@ -66,7 +66,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := builder.NewWriter()
|
buf := builder.NewWriter()
|
||||||
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true); err != nil {
|
if err := statement.writeSelect(buf, strings.Join(sumStrs, ", "), true, true); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
return buf.String(), buf.Args(), nil
|
return buf.String(), buf.Args(), nil
|
||||||
|
@ -122,7 +122,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := builder.NewWriter()
|
buf := builder.NewWriter()
|
||||||
if err := statement.writeSelect(buf, columnStr, true); err != nil {
|
if err := statement.writeSelect(buf, columnStr, true, true); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
return buf.String(), buf.Args(), nil
|
return buf.String(), buf.Args(), nil
|
||||||
|
@ -153,12 +153,6 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
selectSQL = "count(*)"
|
selectSQL = "count(*)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var subQuerySelect string
|
|
||||||
if statement.GroupByStr != "" {
|
|
||||||
subQuerySelect = statement.GroupByStr
|
|
||||||
} else {
|
|
||||||
subQuerySelect = selectSQL
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := builder.NewWriter()
|
buf := builder.NewWriter()
|
||||||
if statement.GroupByStr != "" {
|
if statement.GroupByStr != "" {
|
||||||
|
@ -167,7 +161,14 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := statement.writeSelect(buf, subQuerySelect, false); err != nil {
|
var subQuerySelect string
|
||||||
|
if statement.GroupByStr != "" {
|
||||||
|
subQuerySelect = statement.GroupByStr
|
||||||
|
} else {
|
||||||
|
subQuerySelect = selectSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeSelect(buf, subQuerySelect, false, false); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -364,7 +365,7 @@ func (statement *Statement) writeOracleLimit(w *builder.BytesWriter, columnStr s
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit bool) error {
|
func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr string, needLimit, needOrderBy bool) error {
|
||||||
if err := statement.writeSelectColumns(buf, columnStr); err != nil {
|
if err := statement.writeSelectColumns(buf, columnStr); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -380,9 +381,11 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
|
||||||
if err := statement.writeHaving(buf); err != nil {
|
if err := statement.writeHaving(buf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if needOrderBy {
|
||||||
if err := statement.writeOrderBys(buf); err != nil {
|
if err := statement.writeOrderBys(buf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dialect := statement.dialect
|
dialect := statement.dialect
|
||||||
if needLimit {
|
if needLimit {
|
||||||
|
@ -519,7 +522,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
|
||||||
statement.cond = statement.cond.And(autoCond)
|
statement.cond = statement.cond.And(autoCond)
|
||||||
|
|
||||||
buf := builder.NewWriter()
|
buf := builder.NewWriter()
|
||||||
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true); err != nil {
|
if err := statement.writeSelect(buf, statement.genSelectColumnStr(), true, true); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
return buf.String(), buf.Args(), nil
|
return buf.String(), buf.Args(), nil
|
||||||
|
|
|
@ -427,7 +427,158 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string, args []interface{}) error {
|
func (statement *Statement) GenConditionsFromMap(m interface{}) ([]builder.Cond, error) {
|
||||||
|
switch t := m.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
conds := []builder.Cond{}
|
||||||
|
for k, v := range t {
|
||||||
|
conds = append(conds, builder.Eq{k: v})
|
||||||
|
}
|
||||||
|
return conds, nil
|
||||||
|
case map[string]string:
|
||||||
|
conds := []builder.Cond{}
|
||||||
|
for k, v := range t {
|
||||||
|
conds = append(conds, builder.Eq{k: v})
|
||||||
|
}
|
||||||
|
return conds, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported condition map type %v", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Value, hasPreviousSet bool) error {
|
||||||
|
if v.Type().Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
table := statement.RefTable
|
||||||
|
if !(statement.RefTable != nil && table.Version != "" && statement.CheckVersion) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
verValue, err := table.VersionColumn().ValueOfV(&v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if verValue == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(table.Version), " = ", statement.quote(table.Version), " + 1"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||||
|
for i, expr := range statement.IncrColumns {
|
||||||
|
if i > 0 || hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " + ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(expr.Arg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error {
|
||||||
|
// for update action to like "column = column - ?"
|
||||||
|
for i, expr := range statement.DecrColumns {
|
||||||
|
if i > 0 || hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", statement.quote(expr.ColName), " - ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(expr.Arg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error {
|
||||||
|
// for update action to like "column = expression"
|
||||||
|
for i, expr := range statement.ExprColumns {
|
||||||
|
if i > 0 || hasPreviousSet {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch tp := expr.Arg.(type) {
|
||||||
|
case string:
|
||||||
|
if len(tp) == 0 {
|
||||||
|
tp = "''"
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ", tp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case *builder.Builder:
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ("); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tp.WriteTo(statement.QuoteReplacer(w)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, ")"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, err := fmt.Fprint(w, statement.quote(expr.ColName), " = ?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(expr.Arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error {
|
||||||
|
previousLen := w.Len()
|
||||||
|
for i, colName := range colNames {
|
||||||
|
if i > 0 {
|
||||||
|
if _, err := fmt.Fprint(w, ", "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, colName); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.Append(args...)
|
||||||
|
|
||||||
|
if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||||
|
|
||||||
|
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, v reflect.Value, colNames []string, args []interface{}) error {
|
||||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
if _, err := fmt.Fprintf(updateWriter, "UPDATE"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -450,17 +601,16 @@ func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond
|
||||||
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
if _, err := fmt.Fprint(updateWriter, " SET "); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i, colName := range colNames {
|
previousLen := updateWriter.Len()
|
||||||
if i > 0 {
|
|
||||||
if _, err := fmt.Fprint(updateWriter, ", "); err != nil {
|
if err := statement.writeUpdateSets(updateWriter, v, colNames, args); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if no columns to be updated, return error
|
||||||
|
if previousLen == updateWriter.Len() {
|
||||||
|
return ErrNoColumnsTobeUpdated
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprint(updateWriter, colName); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
updateWriter.Append(args...)
|
|
||||||
|
|
||||||
// write from
|
// write from
|
||||||
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
if err := statement.writeUpdateFrom(updateWriter); err != nil {
|
||||||
|
|
2
rows.go
2
rows.go
|
@ -144,6 +144,8 @@ func (rows *Rows) Close() error {
|
||||||
defer rows.session.Close()
|
defer rows.session.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer rows.session.resetStatement()
|
||||||
|
|
||||||
if rows.rows != nil {
|
if rows.rows != nil {
|
||||||
return rows.rows.Close()
|
return rows.rows.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,17 +5,17 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
|
"xorm.io/xorm/internal/statements"
|
||||||
"xorm.io/xorm/internal/utils"
|
"xorm.io/xorm/internal/utils"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
// enumerated all errors
|
// enumerated all errors
|
||||||
var (
|
var (
|
||||||
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
ErrNoColumnsTobeUpdated = statements.ErrNoColumnsTobeUpdated
|
||||||
)
|
)
|
||||||
|
|
||||||
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
||||||
|
@ -74,9 +74,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
v := utils.ReflectValue(bean)
|
v := utils.ReflectValue(bean)
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
|
||||||
var colNames []string
|
|
||||||
var args []interface{}
|
|
||||||
|
|
||||||
// handle before update processors
|
// handle before update processors
|
||||||
for _, closure := range session.beforeClosures {
|
for _, closure := range session.beforeClosures {
|
||||||
closure(bean)
|
closure(bean)
|
||||||
|
@ -87,6 +84,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
// --
|
// --
|
||||||
|
|
||||||
|
var colNames []string
|
||||||
|
var args []interface{}
|
||||||
var err error
|
var err error
|
||||||
isMap := t.Kind() == reflect.Map
|
isMap := t.Kind() == reflect.Map
|
||||||
isStruct := t.Kind() == reflect.Struct
|
isStruct := t.Kind() == reflect.Struct
|
||||||
|
@ -148,41 +147,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// for update action to like "column = column + ?"
|
|
||||||
incColumns := session.statement.IncrColumns
|
|
||||||
for _, expr := range incColumns {
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" + ?")
|
|
||||||
args = append(args, expr.Arg)
|
|
||||||
}
|
|
||||||
// for update action to like "column = column - ?"
|
|
||||||
decColumns := session.statement.DecrColumns
|
|
||||||
for _, expr := range decColumns {
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+" = "+session.engine.Quote(expr.ColName)+" - ?")
|
|
||||||
args = append(args, expr.Arg)
|
|
||||||
}
|
|
||||||
// for update action to like "column = expression"
|
|
||||||
exprColumns := session.statement.ExprColumns
|
|
||||||
for _, expr := range exprColumns {
|
|
||||||
switch tp := expr.Arg.(type) {
|
|
||||||
case string:
|
|
||||||
if len(tp) == 0 {
|
|
||||||
tp = "''"
|
|
||||||
}
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"="+tp)
|
|
||||||
case *builder.Builder:
|
|
||||||
subQuery, subArgs, err := builder.ToSQL(tp)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
subQuery = session.statement.ReplaceQuote(subQuery)
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=("+subQuery+")")
|
|
||||||
args = append(args, subArgs...)
|
|
||||||
default:
|
|
||||||
colNames = append(colNames, session.engine.Quote(expr.ColName)+"=?")
|
|
||||||
args = append(args, expr.Arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = session.statement.ProcessIDParam(); err != nil {
|
if err = session.statement.ProcessIDParam(); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -211,23 +175,18 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
verValue *reflect.Value
|
verValue *reflect.Value
|
||||||
)
|
)
|
||||||
if doIncVer {
|
if doIncVer {
|
||||||
verValue, err = table.VersionColumn().ValueOf(bean)
|
verValue, err = table.VersionColumn().ValueOfV(&v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if verValue != nil {
|
if verValue != nil {
|
||||||
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
||||||
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(colNames) == 0 {
|
|
||||||
return 0, ErrNoColumnsTobeUpdated
|
|
||||||
}
|
|
||||||
|
|
||||||
updateWriter := builder.NewWriter()
|
updateWriter := builder.NewWriter()
|
||||||
if err := session.statement.WriteUpdate(updateWriter, cond, colNames, args); err != nil {
|
if err := session.statement.WriteUpdate(updateWriter, cond, v, colNames, args); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue