commit
48cf8bb174
|
@ -360,15 +360,15 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session {
|
||||||
return session.NoAutoCondition(no...)
|
return session.NoAutoCondition(no...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (engine *Engine) loadTableInfo(table *schemas.Table) error {
|
func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) error {
|
||||||
colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name)
|
colSeq, cols, err := engine.dialect.GetColumns(engine.db, ctx, table.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, name := range colSeq {
|
for _, name := range colSeq {
|
||||||
table.AddColumn(cols[name])
|
table.AddColumn(cols[name])
|
||||||
}
|
}
|
||||||
indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name)
|
indexes, err := engine.dialect.GetIndexes(engine.db, ctx, table.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -404,7 +404,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
if err = engine.loadTableInfo(table); err != nil {
|
if err = engine.loadTableInfo(engine.defaultContext, table); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,8 +43,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasInsertColumns = len(colNames) > 0
|
hasInsertColumns := len(colNames) > 0
|
||||||
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
needSeq := len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
|
||||||
if needSeq {
|
if needSeq {
|
||||||
for _, col := range colNames {
|
for _, col := range colNames {
|
||||||
if strings.EqualFold(col, table.AutoIncrement) {
|
if strings.EqualFold(col, table.AutoIncrement) {
|
||||||
|
@ -124,11 +124,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := buf.WriteString(" WHERE "); err != nil {
|
if err := statement.writeWhere(buf); err != nil {
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := statement.Conds().WriteTo(buf); err != nil {
|
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -244,6 +244,16 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
|
func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
|
||||||
|
if !statement.cond.IsValid() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return statement.cond.WriteTo(statement.QuoteReplacer(w))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
|
||||||
if !statement.cond.IsValid() {
|
if !statement.cond.IsValid() {
|
||||||
return statement.writeMssqlPaginationCond(w)
|
return statement.writeMssqlPaginationCond(w)
|
||||||
}
|
}
|
||||||
|
@ -307,14 +317,9 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err
|
||||||
if err := statement.writeFrom(subWriter); err != nil {
|
if err := statement.writeFrom(subWriter); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if statement.cond.IsValid() {
|
if err := statement.writeWhere(subWriter); err != nil {
|
||||||
if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := statement.WriteOrderBy(subWriter); err != nil {
|
if err := statement.WriteOrderBy(subWriter); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -361,7 +366,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
|
||||||
if err := statement.writeFrom(buf); err != nil {
|
if err := statement.writeFrom(buf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := statement.writeWhere(buf); err != nil {
|
if err := statement.writeWhereWithMssqlPagination(buf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := statement.writeGroupBy(buf); err != nil {
|
if err := statement.writeGroupBy(buf); err != nil {
|
||||||
|
@ -427,14 +432,9 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
||||||
if err := statement.writeJoins(buf); err != nil {
|
if err := statement.writeJoins(buf); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if statement.Conds().IsValid() {
|
if err := statement.writeWhere(buf); err != nil {
|
||||||
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
|
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if statement.dialect.URI().DBType == schemas.ORACLE {
|
} else if statement.dialect.URI().DBType == schemas.ORACLE {
|
||||||
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
|
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
|
@ -463,14 +463,9 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
||||||
if err := statement.writeJoins(buf); err != nil {
|
if err := statement.writeJoins(buf); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if statement.Conds().IsValid() {
|
if err := statement.writeWhere(buf); err != nil {
|
||||||
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil {
|
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
|
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -299,13 +299,13 @@ func (statement *Statement) writeGroupBy(w builder.Writer) error {
|
||||||
if statement.GroupByStr == "" {
|
if statement.GroupByStr == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
_, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr)
|
_, err := fmt.Fprint(w, " GROUP BY ", statement.GroupByStr)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Having generate "Having conditions" statement
|
// Having generate "Having conditions" statement
|
||||||
func (statement *Statement) Having(conditions string) *Statement {
|
func (statement *Statement) Having(conditions string) *Statement {
|
||||||
statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions))
|
statement.HavingStr = conditions
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,7 +313,7 @@ func (statement *Statement) writeHaving(w builder.Writer) error {
|
||||||
if statement.HavingStr == "" {
|
if statement.HavingStr == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
_, err := fmt.Fprint(w, " ", statement.HavingStr)
|
_, err := fmt.Fprint(w, " HAVING ", statement.ReplaceQuote(statement.HavingStr))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,8 +9,10 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"xorm.io/builder"
|
||||||
"xorm.io/xorm/convert"
|
"xorm.io/xorm/convert"
|
||||||
"xorm.io/xorm/dialects"
|
"xorm.io/xorm/dialects"
|
||||||
"xorm.io/xorm/internal/json"
|
"xorm.io/xorm/internal/json"
|
||||||
|
@ -308,3 +310,85 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
||||||
|
|
||||||
return colNames, args, nil
|
return colNames, args, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
|
||||||
|
whereWriter := builder.NewWriter()
|
||||||
|
if cond.IsValid() {
|
||||||
|
fmt.Fprint(whereWriter, "WHERE ")
|
||||||
|
}
|
||||||
|
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := statement.WriteOrderBy(whereWriter); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
table := statement.RefTable
|
||||||
|
tableName := statement.TableName()
|
||||||
|
// TODO: Oracle support needed
|
||||||
|
var top string
|
||||||
|
if statement.LimitN != nil {
|
||||||
|
limitValue := *statement.LimitN
|
||||||
|
switch statement.dialect.URI().DBType {
|
||||||
|
case schemas.MYSQL:
|
||||||
|
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||||
|
case schemas.SQLITE:
|
||||||
|
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||||
|
|
||||||
|
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||||
|
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||||
|
|
||||||
|
whereWriter = builder.NewWriter()
|
||||||
|
fmt.Fprint(whereWriter, "WHERE ")
|
||||||
|
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case schemas.POSTGRES:
|
||||||
|
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
||||||
|
|
||||||
|
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||||
|
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
||||||
|
|
||||||
|
whereWriter = builder.NewWriter()
|
||||||
|
fmt.Fprint(whereWriter, "WHERE ")
|
||||||
|
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case schemas.MSSQL:
|
||||||
|
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
||||||
|
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||||
|
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||||
|
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)
|
||||||
|
|
||||||
|
whereWriter = builder.NewWriter()
|
||||||
|
fmt.Fprint(whereWriter, "WHERE ")
|
||||||
|
if err := cond.WriteTo(whereWriter); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
top = fmt.Sprintf("TOP (%d) ", limitValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tableAlias := statement.quote(tableName)
|
||||||
|
var fromSQL string
|
||||||
|
if statement.TableAlias != "" {
|
||||||
|
switch statement.dialect.URI().DBType {
|
||||||
|
case schemas.MSSQL:
|
||||||
|
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
|
||||||
|
tableAlias = statement.TableAlias
|
||||||
|
default:
|
||||||
|
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
|
||||||
|
top,
|
||||||
|
tableAlias,
|
||||||
|
strings.Join(colNames, ", "),
|
||||||
|
fromSQL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return utils.WriteBuilder(updateWriter, whereWriter)
|
||||||
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
|
||||||
|
|
||||||
var total int64
|
var total int64
|
||||||
err = session.queryRow(sqlStr, args...).Scan(&total)
|
err = session.queryRow(sqlStr, args...).Scan(&total)
|
||||||
if err == sql.ErrNoRows || err == nil {
|
if err == nil {
|
||||||
return total, nil
|
return total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,12 +70,12 @@ func (session *Session) SumInt(bean interface{}, columnName string) (res int64,
|
||||||
|
|
||||||
// Sums call sum some columns. bean's non-empty fields are conditions.
|
// Sums call sum some columns. bean's non-empty fields are conditions.
|
||||||
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
|
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
|
||||||
var res = make([]float64, len(columnNames))
|
res := make([]float64, len(columnNames))
|
||||||
return res, session.sum(&res, bean, columnNames...)
|
return res, session.sum(&res, bean, columnNames...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SumsInt sum specify columns and return as []int64 instead of []float64
|
// SumsInt sum specify columns and return as []int64 instead of []float64
|
||||||
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
|
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
|
||||||
var res = make([]int64, len(columnNames))
|
res := make([]int64, len(columnNames))
|
||||||
return res, session.sum(&res, bean, columnNames...)
|
return res, session.sum(&res, bean, columnNames...)
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,13 +6,9 @@ package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"xorm.io/builder"
|
"xorm.io/builder"
|
||||||
"xorm.io/xorm/caches"
|
|
||||||
"xorm.io/xorm/internal/utils"
|
"xorm.io/xorm/internal/utils"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
@ -22,124 +18,39 @@ var (
|
||||||
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
|
||||||
)
|
)
|
||||||
|
|
||||||
//revive:disable
|
func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
|
||||||
func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error {
|
if session.statement.NoAutoCondition {
|
||||||
if table == nil ||
|
return builder.NewCond(), nil
|
||||||
session.tx != nil {
|
|
||||||
return ErrCacheFailed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr)
|
if c, ok := condiBean.(map[string]interface{}); ok {
|
||||||
if newsql == "" {
|
eq := make(builder.Eq)
|
||||||
return ErrCacheFailed
|
for k, v := range c {
|
||||||
}
|
eq[session.engine.Quote(k)] = v
|
||||||
for _, filter := range session.engine.dialect.Filters() {
|
|
||||||
newsql = filter.Do(session.ctx, newsql)
|
|
||||||
}
|
|
||||||
session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)
|
|
||||||
|
|
||||||
var nStart int
|
|
||||||
if len(args) > 0 {
|
|
||||||
if strings.Contains(sqlStr, "?") {
|
|
||||||
nStart = strings.Count(oldhead, "?")
|
|
||||||
} else {
|
|
||||||
// only for pq, TODO: if any other databse?
|
|
||||||
nStart = strings.Count(oldhead, "$")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cacher := session.engine.GetCacher(tableName)
|
if session.statement.RefTable != nil {
|
||||||
session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:])
|
if col := session.statement.RefTable.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
|
||||||
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:])
|
return eq.And(session.statement.CondDeleted(col)), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return eq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ct := reflect.TypeOf(condiBean)
|
||||||
|
k := ct.Kind()
|
||||||
|
if k == reflect.Ptr {
|
||||||
|
k = ct.Elem().Kind()
|
||||||
|
}
|
||||||
|
if k != reflect.Struct {
|
||||||
|
return nil, ErrConditionType
|
||||||
|
}
|
||||||
|
|
||||||
|
condTable, err := session.engine.TableInfo(condiBean)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...)
|
return nil, err
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
return session.statement.BuildConds(condTable, condiBean, true, true, false, true, false)
|
||||||
|
|
||||||
ids = make([]schemas.PK, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
res := make([]string, len(table.PrimaryKeys))
|
|
||||||
err = rows.ScanSlice(&res)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
|
|
||||||
for i, col := range table.PKColumns() {
|
|
||||||
if col.SQLType.IsNumeric() {
|
|
||||||
n, err := strconv.ParseInt(res[i], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
pk[i] = n
|
|
||||||
} else if col.SQLType.IsText() {
|
|
||||||
pk[i] = res[i]
|
|
||||||
} else {
|
|
||||||
return errors.New("not supported")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ids = append(ids, pk)
|
|
||||||
}
|
|
||||||
if rows.Err() != nil {
|
|
||||||
return rows.Err()
|
|
||||||
}
|
|
||||||
session.engine.logger.Debugf("[cache] find updated id: %v", ids)
|
|
||||||
} /*else {
|
|
||||||
session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
|
|
||||||
cacher.DelIds(tableName, genSqlKey(newsql, args))
|
|
||||||
}*/
|
|
||||||
|
|
||||||
for _, id := range ids {
|
|
||||||
sid, err := id.ToString()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if bean := cacher.GetBean(tableName, sid); bean != nil {
|
|
||||||
sqls := utils.SplitNNoCase(sqlStr, "where", 2)
|
|
||||||
if len(sqls) == 0 || len(sqls) > 2 {
|
|
||||||
return ErrCacheFailed
|
|
||||||
}
|
|
||||||
|
|
||||||
sqls = utils.SplitNNoCase(sqls[0], "set", 2)
|
|
||||||
if len(sqls) != 2 {
|
|
||||||
return ErrCacheFailed
|
|
||||||
}
|
|
||||||
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
|
|
||||||
|
|
||||||
for idx, kv := range kvs {
|
|
||||||
sps := strings.SplitN(kv, "=", 2)
|
|
||||||
sps2 := strings.Split(sps[0], ".")
|
|
||||||
colName := sps2[len(sps2)-1]
|
|
||||||
colName = session.engine.dialect.Quoter().Trim(colName)
|
|
||||||
colName = schemas.CommonQuoter.Trim(colName)
|
|
||||||
|
|
||||||
if col := table.GetColumn(colName); col != nil {
|
|
||||||
fieldValue, err := col.ValueOf(bean)
|
|
||||||
if err != nil {
|
|
||||||
session.engine.logger.Errorf("%v", err)
|
|
||||||
} else {
|
|
||||||
session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface())
|
|
||||||
if col.IsVersion && session.statement.CheckVersion {
|
|
||||||
session.incrVersionFieldValue(fieldValue)
|
|
||||||
} else {
|
|
||||||
fieldValue.Set(reflect.ValueOf(args[idx]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's",
|
|
||||||
colName, table.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean)
|
|
||||||
cacher.PutBean(tableName, sid, bean)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName)
|
|
||||||
cacher.ClearIds(tableName)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update records, bean's non-empty fields are updated contents,
|
// Update records, bean's non-empty fields are updated contents,
|
||||||
|
@ -277,39 +188,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
|
|
||||||
var autoCond builder.Cond
|
var autoCond builder.Cond
|
||||||
if !session.statement.NoAutoCondition {
|
|
||||||
condBeanIsStruct := false
|
|
||||||
if len(condiBean) > 0 {
|
if len(condiBean) > 0 {
|
||||||
if c, ok := condiBean[0].(map[string]interface{}); ok {
|
autoCond, err = session.genAutoCond(condiBean[0])
|
||||||
eq := make(builder.Eq)
|
|
||||||
for k, v := range c {
|
|
||||||
eq[session.engine.Quote(k)] = v
|
|
||||||
}
|
|
||||||
autoCond = builder.Eq(eq)
|
|
||||||
} else {
|
|
||||||
ct := reflect.TypeOf(condiBean[0])
|
|
||||||
k := ct.Kind()
|
|
||||||
if k == reflect.Ptr {
|
|
||||||
k = ct.Elem().Kind()
|
|
||||||
}
|
|
||||||
if k == reflect.Struct {
|
|
||||||
condTable, err := session.engine.TableInfo(condiBean[0])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
} else if table != nil {
|
||||||
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
condBeanIsStruct = true
|
|
||||||
} else {
|
|
||||||
return 0, ErrConditionType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !condBeanIsStruct && table != nil {
|
|
||||||
if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
|
if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
|
||||||
autoCond1 := session.statement.CondDeleted(col)
|
autoCond1 := session.statement.CondDeleted(col)
|
||||||
|
|
||||||
|
@ -320,9 +204,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
st := session.statement
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
cond = session.statement.Conds().And(autoCond)
|
cond = session.statement.Conds().And(autoCond)
|
||||||
|
@ -345,88 +226,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
return 0, ErrNoColumnsTobeUpdated
|
return 0, ErrNoColumnsTobeUpdated
|
||||||
}
|
}
|
||||||
|
|
||||||
whereWriter := builder.NewWriter()
|
|
||||||
if cond.IsValid() {
|
|
||||||
fmt.Fprint(whereWriter, "WHERE ")
|
|
||||||
}
|
|
||||||
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if err := st.WriteOrderBy(whereWriter); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tableName := session.statement.TableName()
|
|
||||||
// TODO: Oracle support needed
|
|
||||||
var top string
|
|
||||||
if st.LimitN != nil {
|
|
||||||
limitValue := *st.LimitN
|
|
||||||
switch session.engine.dialect.URI().DBType {
|
|
||||||
case schemas.MYSQL:
|
|
||||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
|
||||||
case schemas.SQLITE:
|
|
||||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
|
||||||
|
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
|
||||||
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
|
||||||
|
|
||||||
whereWriter = builder.NewWriter()
|
|
||||||
fmt.Fprint(whereWriter, "WHERE ")
|
|
||||||
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
case schemas.POSTGRES:
|
|
||||||
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
|
|
||||||
|
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
|
||||||
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
|
|
||||||
|
|
||||||
whereWriter = builder.NewWriter()
|
|
||||||
fmt.Fprint(whereWriter, "WHERE ")
|
|
||||||
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
case schemas.MSSQL:
|
|
||||||
if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
|
|
||||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
|
||||||
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
|
||||||
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
|
|
||||||
|
|
||||||
whereWriter = builder.NewWriter()
|
|
||||||
fmt.Fprint(whereWriter, "WHERE ")
|
|
||||||
if err := cond.WriteTo(whereWriter); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
top = fmt.Sprintf("TOP (%d) ", limitValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tableAlias := session.engine.Quote(tableName)
|
|
||||||
var fromSQL string
|
|
||||||
if session.statement.TableAlias != "" {
|
|
||||||
switch session.engine.dialect.URI().DBType {
|
|
||||||
case schemas.MSSQL:
|
|
||||||
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
|
|
||||||
tableAlias = session.statement.TableAlias
|
|
||||||
default:
|
|
||||||
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
updateWriter := builder.NewWriter()
|
updateWriter := builder.NewWriter()
|
||||||
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
|
if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil {
|
||||||
top,
|
|
||||||
tableAlias,
|
|
||||||
strings.Join(colNames, ", "),
|
|
||||||
fromSQL); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
|
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tableName := session.statement.TableName() // table name must been get before exec because statement will be reset
|
||||||
|
useCache := session.statement.UseCache
|
||||||
|
|
||||||
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
|
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -436,8 +243,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache {
|
if cacher := session.engine.GetCacher(tableName); cacher != nil && useCache {
|
||||||
// session.cacheUpdate(table, tableName, sqlStr, args...)
|
|
||||||
session.engine.logger.Debugf("[cache] clear table: %v", tableName)
|
session.engine.logger.Debugf("[cache] clear table: %v", tableName)
|
||||||
cacher.ClearIds(tableName)
|
cacher.ClearIds(tableName)
|
||||||
cacher.ClearBeans(tableName)
|
cacher.ClearBeans(tableName)
|
||||||
|
|
2
sync.go
2
sync.go
|
@ -116,7 +116,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// this will modify an old table
|
// this will modify an old table
|
||||||
if err = engine.loadTableInfo(oriTable); err != nil {
|
if err = engine.loadTableInfo(session.ctx, oriTable); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue