some refactors
This commit is contained in:
parent
f33221df74
commit
6a783823fa
|
@ -43,8 +43,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
var hasInsertColumns = len(colNames) > 0
|
||||
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||
hasInsertColumns := len(colNames) > 0
|
||||
needSeq := len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||
if needSeq {
|
||||
for _, col := range colNames {
|
||||
if strings.EqualFold(col, table.AutoIncrement) {
|
||||
|
@ -124,11 +124,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.WriteString(" WHERE "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if err := statement.Conds().WriteTo(buf); err != nil {
|
||||
if err := statement.writeWhere(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -244,6 +244,16 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
|
|||
}
|
||||
|
||||
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
|
||||
if !statement.cond.IsValid() {
|
||||
return nil
|
||||
}
|
||||
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
return statement.cond.WriteTo(statement.QuoteReplacer(w))
|
||||
}
|
||||
|
||||
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
|
||||
if !statement.cond.IsValid() {
|
||||
return statement.writeMssqlPaginationCond(w)
|
||||
}
|
||||
|
@ -307,13 +317,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err
|
|||
if err := statement.writeFrom(subWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
if statement.cond.IsValid() {
|
||||
if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeWhere(subWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.WriteOrderBy(subWriter); err != nil {
|
||||
return err
|
||||
|
@ -361,7 +366,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
|
|||
if err := statement.writeFrom(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeWhere(buf); err != nil {
|
||||
if err := statement.writeWhereWithMssqlPagination(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.writeGroupBy(buf); err != nil {
|
||||
|
@ -427,13 +432,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
|||
if err := statement.writeJoins(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if statement.Conds().IsValid() {
|
||||
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.writeWhere(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
} else if statement.dialect.URI().DBType == schemas.ORACLE {
|
||||
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
|
||||
|
@ -463,13 +463,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
|||
if err := statement.writeJoins(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if statement.Conds().IsValid() {
|
||||
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if err := statement.writeWhere(buf); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
|
||||
return "", nil, err
|
||||
|
|
|
@ -9,8 +9,10 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"xorm.io/builder"
|
||||
"xorm.io/xorm/convert"
|
||||
"xorm.io/xorm/dialects"
|
||||
"xorm.io/xorm/internal/json"
|
||||
|
@ -308,3 +310,85 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
|||
|
||||
return colNames, args, nil
|
||||
}
|
||||
|
||||
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
|
||||
whereWriter := builder.NewWriter()
|
||||
if cond.IsValid() {
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
}
|
||||
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := statement.WriteOrderBy(whereWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
table := statement.RefTable
|
||||
tableName := statement.TableName()
|
||||
// TODO: Oracle support needed
|
||||
var top string
|
||||
if statement.LimitN != nil {
|
||||
limitValue := *statement.LimitN
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.MYSQL:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
case schemas.SQLITE:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemas.POSTGRES:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||
return err
|
||||
}
|
||||
case schemas.MSSQL:
|
||||
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(whereWriter); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
top = fmt.Sprintf("TOP (%d) ", limitValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tableAlias := statement.quote(tableName)
|
||||
var fromSQL string
|
||||
if statement.TableAlias != "" {
|
||||
switch statement.dialect.URI().DBType {
|
||||
case schemas.MSSQL:
|
||||
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
|
||||
tableAlias = statement.TableAlias
|
||||
default:
|
||||
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
|
||||
top,
|
||||
tableAlias,
|
||||
strings.Join(colNames, ", "),
|
||||
fromSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
return utils.WriteBuilder(updateWriter, whereWriter)
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ package xorm
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -142,6 +141,43 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
|
|||
return nil
|
||||
}
|
||||
|
||||
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
||||
if session.statement.NoAutoCondition {
|
||||
return builder.NewCond(), nil
|
||||
}
|
||||
|
||||
if c, ok := condiBean.(map[string]interface{}); ok {
|
||||
eq := make(builder.Eq)
|
||||
for k, v := range c {
|
||||
eq[session.engine.Quote(k)] = v
|
||||
}
|
||||
autoCond := builder.Eq(eq)
|
||||
|
||||
if session.statement.RefTable != nil {
|
||||
if col := session.statement.RefTable.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
|
||||
return autoCond.And(session.statement.CondDeleted(col)), nil
|
||||
}
|
||||
}
|
||||
return autoCond, nil
|
||||
}
|
||||
|
||||
ct := reflect.TypeOf(condiBean)
|
||||
k := ct.Kind()
|
||||
if k == reflect.Ptr {
|
||||
k = ct.Elem().Kind()
|
||||
}
|
||||
if k == reflect.Struct {
|
||||
condTable, err := session.engine.TableInfo(condiBean)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.statement.BuildConds(condTable, condiBean, true, true, false, true, false)
|
||||
}
|
||||
|
||||
return nil, ErrConditionType
|
||||
}
|
||||
|
||||
// Update records, bean's non-empty fields are updated contents,
|
||||
// condiBean' non-empty filds are conditions
|
||||
// CAUTION:
|
||||
|
@ -276,54 +312,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
return 0, err
|
||||
}
|
||||
|
||||
var autoCond builder.Cond
|
||||
if !session.statement.NoAutoCondition {
|
||||
condBeanIsStruct := false
|
||||
if len(condiBean) > 0 {
|
||||
if c, ok := condiBean[0].(map[string]interface{}); ok {
|
||||
eq := make(builder.Eq)
|
||||
for k, v := range c {
|
||||
eq[session.engine.Quote(k)] = v
|
||||
}
|
||||
autoCond = builder.Eq(eq)
|
||||
} else {
|
||||
ct := reflect.TypeOf(condiBean[0])
|
||||
k := ct.Kind()
|
||||
if k == reflect.Ptr {
|
||||
k = ct.Elem().Kind()
|
||||
}
|
||||
if k == reflect.Struct {
|
||||
condTable, err := session.engine.TableInfo(condiBean[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
condBeanIsStruct = true
|
||||
} else {
|
||||
return 0, ErrConditionType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !condBeanIsStruct && table != nil {
|
||||
if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
|
||||
autoCond1 := session.statement.CondDeleted(col)
|
||||
|
||||
if autoCond == nil {
|
||||
autoCond = autoCond1
|
||||
} else {
|
||||
autoCond = autoCond.And(autoCond1)
|
||||
}
|
||||
}
|
||||
autoCond := builder.NewCond()
|
||||
if len(condiBean) > 0 {
|
||||
autoCond, err = session.genAutoCond(condiBean[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
st := session.statement
|
||||
|
||||
var (
|
||||
cond = session.statement.Conds().And(autoCond)
|
||||
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
|
||||
|
@ -345,85 +341,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
return 0, ErrNoColumnsTobeUpdated
|
||||
}
|
||||
|
||||
whereWriter := builder.NewWriter()
|
||||
if cond.IsValid() {
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
}
|
||||
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := st.WriteOrderBy(whereWriter); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tableName := session.statement.TableName()
|
||||
// TODO: Oracle support needed
|
||||
var top string
|
||||
if st.LimitN != nil {
|
||||
limitValue := *st.LimitN
|
||||
switch session.engine.dialect.URI().DBType {
|
||||
case schemas.MYSQL:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
case schemas.SQLITE:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case schemas.POSTGRES:
|
||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case schemas.MSSQL:
|
||||
if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
|
||||
|
||||
whereWriter = builder.NewWriter()
|
||||
fmt.Fprint(whereWriter, "WHERE ")
|
||||
if err := cond.WriteTo(whereWriter); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
top = fmt.Sprintf("TOP (%d) ", limitValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tableAlias := session.engine.Quote(tableName)
|
||||
var fromSQL string
|
||||
if session.statement.TableAlias != "" {
|
||||
switch session.engine.dialect.URI().DBType {
|
||||
case schemas.MSSQL:
|
||||
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
|
||||
tableAlias = session.statement.TableAlias
|
||||
default:
|
||||
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias)
|
||||
}
|
||||
}
|
||||
|
||||
updateWriter := builder.NewWriter()
|
||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
|
||||
top,
|
||||
tableAlias,
|
||||
strings.Join(colNames, ", "),
|
||||
fromSQL); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
|
||||
if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
@ -436,6 +355,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
}
|
||||
}
|
||||
|
||||
tableName := session.statement.TableName()
|
||||
if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache {
|
||||
// session.cacheUpdate(table, tableName, sqlStr, args...)
|
||||
session.engine.logger.Debugf("[cache] clear table: %v", tableName)
|
||||
|
|
Loading…
Reference in New Issue