some refactors

This commit is contained in:
Lunny Xiao 2023-07-13 13:20:34 +08:00
parent f33221df74
commit 6a783823fa
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
4 changed files with 148 additions and 153 deletions

View File

@ -43,8 +43,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err
}
var hasInsertColumns = len(colNames) > 0
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
hasInsertColumns := len(colNames) > 0
needSeq := len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if needSeq {
for _, col := range colNames {
if strings.EqualFold(col, table.AutoIncrement) {
@ -124,11 +124,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err
}
if _, err := buf.WriteString(" WHERE "); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(buf); err != nil {
if err := statement.writeWhere(buf); err != nil {
return "", nil, err
}
} else {

View File

@ -244,6 +244,16 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
}
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
return nil
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
return statement.cond.WriteTo(statement.QuoteReplacer(w))
}
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
return statement.writeMssqlPaginationCond(w)
}
@ -307,14 +317,9 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err
if err := statement.writeFrom(subWriter); err != nil {
return err
}
if statement.cond.IsValid() {
if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil {
if err := statement.writeWhere(subWriter); err != nil {
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil {
return err
}
}
if err := statement.WriteOrderBy(subWriter); err != nil {
return err
}
@ -361,7 +366,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
if err := statement.writeFrom(buf); err != nil {
return err
}
if err := statement.writeWhere(buf); err != nil {
if err := statement.writeWhereWithMssqlPagination(buf); err != nil {
return err
}
if err := statement.writeGroupBy(buf); err != nil {
@ -427,14 +432,9 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if err := statement.writeJoins(buf); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
if err := statement.writeWhere(buf); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
}
}
} else if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
return "", nil, err
@ -463,14 +463,9 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if err := statement.writeJoins(buf); err != nil {
return "", nil, err
}
if statement.Conds().IsValid() {
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
if err := statement.writeWhere(buf); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
}
}
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
return "", nil, err
}

View File

@ -9,8 +9,10 @@ import (
"errors"
"fmt"
"reflect"
"strings"
"time"
"xorm.io/builder"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json"
@ -308,3 +310,85 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
return colNames, args, nil
}
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
whereWriter := builder.NewWriter()
if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
if err := statement.WriteOrderBy(whereWriter); err != nil {
return err
}
table := statement.RefTable
tableName := statement.TableName()
// TODO: Oracle support needed
var top string
if statement.LimitN != nil {
limitValue := *statement.LimitN
switch statement.dialect.URI().DBType {
case schemas.MYSQL:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
case schemas.POSTGRES:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
case schemas.MSSQL:
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return err
}
} else {
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
tableAlias := statement.quote(tableName)
var fromSQL string
if statement.TableAlias != "" {
switch statement.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
tableAlias = statement.TableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
}
}
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top,
tableAlias,
strings.Join(colNames, ", "),
fromSQL); err != nil {
return err
}
return utils.WriteBuilder(updateWriter, whereWriter)
}

View File

@ -6,7 +6,6 @@ package xorm
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
@ -142,6 +141,43 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri
return nil
}
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
if session.statement.NoAutoCondition {
return builder.NewCond(), nil
}
if c, ok := condiBean.(map[string]interface{}); ok {
eq := make(builder.Eq)
for k, v := range c {
eq[session.engine.Quote(k)] = v
}
autoCond := builder.Eq(eq)
if session.statement.RefTable != nil {
if col := session.statement.RefTable.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
return autoCond.And(session.statement.CondDeleted(col)), nil
}
}
return autoCond, nil
}
ct := reflect.TypeOf(condiBean)
k := ct.Kind()
if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k == reflect.Struct {
condTable, err := session.engine.TableInfo(condiBean)
if err != nil {
return nil, err
}
return session.statement.BuildConds(condTable, condiBean, true, true, false, true, false)
}
return nil, ErrConditionType
}
// Update records, bean's non-empty fields are updated contents,
// condiBean' non-empty filds are conditions
// CAUTION:
@ -276,53 +312,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err
}
var autoCond builder.Cond
if !session.statement.NoAutoCondition {
condBeanIsStruct := false
autoCond := builder.NewCond()
if len(condiBean) > 0 {
if c, ok := condiBean[0].(map[string]interface{}); ok {
eq := make(builder.Eq)
for k, v := range c {
eq[session.engine.Quote(k)] = v
}
autoCond = builder.Eq(eq)
} else {
ct := reflect.TypeOf(condiBean[0])
k := ct.Kind()
if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k == reflect.Struct {
condTable, err := session.engine.TableInfo(condiBean[0])
autoCond, err = session.genAutoCond(condiBean[0])
if err != nil {
return 0, err
}
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
condBeanIsStruct = true
} else {
return 0, ErrConditionType
}
}
}
if !condBeanIsStruct && table != nil {
if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
autoCond1 := session.statement.CondDeleted(col)
if autoCond == nil {
autoCond = autoCond1
} else {
autoCond = autoCond.And(autoCond1)
}
}
}
}
st := session.statement
var (
cond = session.statement.Conds().And(autoCond)
@ -345,85 +341,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrNoColumnsTobeUpdated
}
whereWriter := builder.NewWriter()
if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
if err := st.WriteOrderBy(whereWriter); err != nil {
return 0, err
}
tableName := session.statement.TableName()
// TODO: Oracle support needed
var top string
if st.LimitN != nil {
limitValue := *st.LimitN
switch session.engine.dialect.URI().DBType {
case schemas.MYSQL:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
case schemas.POSTGRES:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
case schemas.MSSQL:
if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
} else {
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
tableAlias := session.engine.Quote(tableName)
var fromSQL string
if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias)
}
}
updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top,
tableAlias,
strings.Join(colNames, ", "),
fromSQL); err != nil {
return 0, err
}
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil {
return 0, err
}
@ -436,6 +355,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}
tableName := session.statement.TableName()
if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache {
// session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debugf("[cache] clear table: %v", tableName)