Merge branch 'master'

# Conflicts:
#	convert/time.go
This commit is contained in:
CyJay 2023-07-21 09:07:56 +08:00
commit 48cf8bb174
8 changed files with 157 additions and 276 deletions

View File

@ -360,15 +360,15 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session {
return session.NoAutoCondition(no...) return session.NoAutoCondition(no...)
} }
func (engine *Engine) loadTableInfo(table *schemas.Table) error { func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) error {
colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name) colSeq, cols, err := engine.dialect.GetColumns(engine.db, ctx, table.Name)
if err != nil { if err != nil {
return err return err
} }
for _, name := range colSeq { for _, name := range colSeq {
table.AddColumn(cols[name]) table.AddColumn(cols[name])
} }
indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name) indexes, err := engine.dialect.GetIndexes(engine.db, ctx, table.Name)
if err != nil { if err != nil {
return err return err
} }
@ -404,7 +404,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) {
} }
for _, table := range tables { for _, table := range tables {
if err = engine.loadTableInfo(table); err != nil { if err = engine.loadTableInfo(engine.defaultContext, table); err != nil {
return nil, err return nil, err
} }
} }

View File

@ -43,8 +43,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err return "", nil, err
} }
var hasInsertColumns = len(colNames) > 0 hasInsertColumns := len(colNames) > 0
var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) needSeq := len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG)
if needSeq { if needSeq {
for _, col := range colNames { for _, col := range colNames {
if strings.EqualFold(col, table.AutoIncrement) { if strings.EqualFold(col, table.AutoIncrement) {
@ -124,11 +124,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err return "", nil, err
} }
if _, err := buf.WriteString(" WHERE "); err != nil { if err := statement.writeWhere(buf); err != nil {
return "", nil, err
}
if err := statement.Conds().WriteTo(buf); err != nil {
return "", nil, err return "", nil, err
} }
} else { } else {

View File

@ -244,6 +244,16 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr
} }
func (statement *Statement) writeWhere(w *builder.BytesWriter) error { func (statement *Statement) writeWhere(w *builder.BytesWriter) error {
if !statement.cond.IsValid() {
return nil
}
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
return err
}
return statement.cond.WriteTo(statement.QuoteReplacer(w))
}
func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error {
if !statement.cond.IsValid() { if !statement.cond.IsValid() {
return statement.writeMssqlPaginationCond(w) return statement.writeMssqlPaginationCond(w)
} }
@ -307,13 +317,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err
if err := statement.writeFrom(subWriter); err != nil { if err := statement.writeFrom(subWriter); err != nil {
return err return err
} }
if statement.cond.IsValid() { if err := statement.writeWhere(subWriter); err != nil {
if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil { return err
return err
}
if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil {
return err
}
} }
if err := statement.WriteOrderBy(subWriter); err != nil { if err := statement.WriteOrderBy(subWriter); err != nil {
return err return err
@ -361,7 +366,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri
if err := statement.writeFrom(buf); err != nil { if err := statement.writeFrom(buf); err != nil {
return err return err
} }
if err := statement.writeWhere(buf); err != nil { if err := statement.writeWhereWithMssqlPagination(buf); err != nil {
return err return err
} }
if err := statement.writeGroupBy(buf); err != nil { if err := statement.writeGroupBy(buf); err != nil {
@ -427,13 +432,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if err := statement.writeJoins(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if err := statement.writeWhere(buf); err != nil {
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { return "", nil, err
return "", nil, err
}
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
}
} }
} else if statement.dialect.URI().DBType == schemas.ORACLE { } else if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil {
@ -463,13 +463,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
if err := statement.writeJoins(buf); err != nil { if err := statement.writeJoins(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.Conds().IsValid() { if err := statement.writeWhere(buf); err != nil {
if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { return "", nil, err
return "", nil, err
}
if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil {
return "", nil, err
}
} }
if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil { if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil {
return "", nil, err return "", nil, err

View File

@ -299,13 +299,13 @@ func (statement *Statement) writeGroupBy(w builder.Writer) error {
if statement.GroupByStr == "" { if statement.GroupByStr == "" {
return nil return nil
} }
_, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr) _, err := fmt.Fprint(w, " GROUP BY ", statement.GroupByStr)
return err return err
} }
// Having generate "Having conditions" statement // Having generate "Having conditions" statement
func (statement *Statement) Having(conditions string) *Statement { func (statement *Statement) Having(conditions string) *Statement {
statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) statement.HavingStr = conditions
return statement return statement
} }
@ -313,7 +313,7 @@ func (statement *Statement) writeHaving(w builder.Writer) error {
if statement.HavingStr == "" { if statement.HavingStr == "" {
return nil return nil
} }
_, err := fmt.Fprint(w, " ", statement.HavingStr) _, err := fmt.Fprint(w, " HAVING ", statement.ReplaceQuote(statement.HavingStr))
return err return err
} }

View File

@ -9,8 +9,10 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
"time" "time"
"xorm.io/builder"
"xorm.io/xorm/convert" "xorm.io/xorm/convert"
"xorm.io/xorm/dialects" "xorm.io/xorm/dialects"
"xorm.io/xorm/internal/json" "xorm.io/xorm/internal/json"
@ -308,3 +310,85 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
return colNames, args, nil return colNames, args, nil
} }
func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error {
whereWriter := builder.NewWriter()
if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
if err := statement.WriteOrderBy(whereWriter); err != nil {
return err
}
table := statement.RefTable
tableName := statement.TableName()
// TODO: Oracle support needed
var top string
if statement.LimitN != nil {
limitValue := *statement.LimitN
switch statement.dialect.URI().DBType {
case schemas.MYSQL:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
case schemas.POSTGRES:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil {
return err
}
case schemas.MSSQL:
if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return err
}
} else {
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
tableAlias := statement.quote(tableName)
var fromSQL string
if statement.TableAlias != "" {
switch statement.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias)
tableAlias = statement.TableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias)
}
}
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top,
tableAlias,
strings.Join(colNames, ", "),
fromSQL); err != nil {
return err
}
return utils.WriteBuilder(updateWriter, whereWriter)
}

View File

@ -24,7 +24,7 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
var total int64 var total int64
err = session.queryRow(sqlStr, args...).Scan(&total) err = session.queryRow(sqlStr, args...).Scan(&total)
if err == sql.ErrNoRows || err == nil { if err == nil {
return total, nil return total, nil
} }
@ -70,12 +70,12 @@ func (session *Session) SumInt(bean interface{}, columnName string) (res int64,
// Sums call sum some columns. bean's non-empty fields are conditions. // Sums call sum some columns. bean's non-empty fields are conditions.
func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) {
var res = make([]float64, len(columnNames)) res := make([]float64, len(columnNames))
return res, session.sum(&res, bean, columnNames...) return res, session.sum(&res, bean, columnNames...)
} }
// SumsInt sum specify columns and return as []int64 instead of []float64 // SumsInt sum specify columns and return as []int64 instead of []float64
func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) {
var res = make([]int64, len(columnNames)) res := make([]int64, len(columnNames))
return res, session.sum(&res, bean, columnNames...) return res, session.sum(&res, bean, columnNames...)
} }

View File

@ -6,13 +6,9 @@ package xorm
import ( import (
"errors" "errors"
"fmt"
"reflect" "reflect"
"strconv"
"strings"
"xorm.io/builder" "xorm.io/builder"
"xorm.io/xorm/caches"
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
@ -22,124 +18,39 @@ var (
ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated")
) )
//revive:disable func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) {
func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { if session.statement.NoAutoCondition {
if table == nil || return builder.NewCond(), nil
session.tx != nil {
return ErrCacheFailed
} }
oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr) if c, ok := condiBean.(map[string]interface{}); ok {
if newsql == "" { eq := make(builder.Eq)
return ErrCacheFailed for k, v := range c {
} eq[session.engine.Quote(k)] = v
for _, filter := range session.engine.dialect.Filters() {
newsql = filter.Do(session.ctx, newsql)
}
session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql)
var nStart int
if len(args) > 0 {
if strings.Contains(sqlStr, "?") {
nStart = strings.Count(oldhead, "?")
} else {
// only for pq, TODO: if any other databse?
nStart = strings.Count(oldhead, "$")
} }
if session.statement.RefTable != nil {
if col := session.statement.RefTable.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
return eq.And(session.statement.CondDeleted(col)), nil
}
}
return eq, nil
} }
cacher := session.engine.GetCacher(tableName) ct := reflect.TypeOf(condiBean)
session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:]) k := ct.Kind()
ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k != reflect.Struct {
return nil, ErrConditionType
}
condTable, err := session.engine.TableInfo(condiBean)
if err != nil { if err != nil {
rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) return nil, err
if err != nil {
return err
}
defer rows.Close()
ids = make([]schemas.PK, 0)
for rows.Next() {
res := make([]string, len(table.PrimaryKeys))
err = rows.ScanSlice(&res)
if err != nil {
return err
}
var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys))
for i, col := range table.PKColumns() {
if col.SQLType.IsNumeric() {
n, err := strconv.ParseInt(res[i], 10, 64)
if err != nil {
return err
}
pk[i] = n
} else if col.SQLType.IsText() {
pk[i] = res[i]
} else {
return errors.New("not supported")
}
}
ids = append(ids, pk)
}
if rows.Err() != nil {
return rows.Err()
}
session.engine.logger.Debugf("[cache] find updated id: %v", ids)
} /*else {
session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args)
cacher.DelIds(tableName, genSqlKey(newsql, args))
}*/
for _, id := range ids {
sid, err := id.ToString()
if err != nil {
return err
}
if bean := cacher.GetBean(tableName, sid); bean != nil {
sqls := utils.SplitNNoCase(sqlStr, "where", 2)
if len(sqls) == 0 || len(sqls) > 2 {
return ErrCacheFailed
}
sqls = utils.SplitNNoCase(sqls[0], "set", 2)
if len(sqls) != 2 {
return ErrCacheFailed
}
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")
for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1]
colName = session.engine.dialect.Quoter().Trim(colName)
colName = schemas.CommonQuoter.Trim(colName)
if col := table.GetColumn(colName); col != nil {
fieldValue, err := col.ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else {
session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface())
if col.IsVersion && session.statement.CheckVersion {
session.incrVersionFieldValue(fieldValue)
} else {
fieldValue.Set(reflect.ValueOf(args[idx]))
}
}
} else {
session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's",
colName, table.Name)
}
}
session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean)
cacher.PutBean(tableName, sid, bean)
}
} }
session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName) return session.statement.BuildConds(condTable, condiBean, true, true, false, true, false)
cacher.ClearIds(tableName)
return nil
} }
// Update records, bean's non-empty fields are updated contents, // Update records, bean's non-empty fields are updated contents,
@ -277,53 +188,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
var autoCond builder.Cond var autoCond builder.Cond
if !session.statement.NoAutoCondition { if len(condiBean) > 0 {
condBeanIsStruct := false autoCond, err = session.genAutoCond(condiBean[0])
if len(condiBean) > 0 { if err != nil {
if c, ok := condiBean[0].(map[string]interface{}); ok { return 0, err
eq := make(builder.Eq)
for k, v := range c {
eq[session.engine.Quote(k)] = v
}
autoCond = builder.Eq(eq)
} else {
ct := reflect.TypeOf(condiBean[0])
k := ct.Kind()
if k == reflect.Ptr {
k = ct.Elem().Kind()
}
if k == reflect.Struct {
condTable, err := session.engine.TableInfo(condiBean[0])
if err != nil {
return 0, err
}
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
if err != nil {
return 0, err
}
condBeanIsStruct = true
} else {
return 0, ErrConditionType
}
}
} }
} else if table != nil {
if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled
autoCond1 := session.statement.CondDeleted(col)
if !condBeanIsStruct && table != nil { if autoCond == nil {
if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled autoCond = autoCond1
autoCond1 := session.statement.CondDeleted(col) } else {
autoCond = autoCond.And(autoCond1)
if autoCond == nil {
autoCond = autoCond1
} else {
autoCond = autoCond.And(autoCond1)
}
} }
} }
} }
st := session.statement
var ( var (
cond = session.statement.Conds().And(autoCond) cond = session.statement.Conds().And(autoCond)
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion)
@ -345,88 +226,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, ErrNoColumnsTobeUpdated return 0, ErrNoColumnsTobeUpdated
} }
whereWriter := builder.NewWriter()
if cond.IsValid() {
fmt.Fprint(whereWriter, "WHERE ")
}
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
if err := st.WriteOrderBy(whereWriter); err != nil {
return 0, err
}
tableName := session.statement.TableName()
// TODO: Oracle support needed
var top string
if st.LimitN != nil {
limitValue := *st.LimitN
switch session.engine.dialect.URI().DBType {
case schemas.MYSQL:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
case schemas.SQLITE:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
case schemas.POSTGRES:
fmt.Fprintf(whereWriter, " LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...))
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil {
return 0, err
}
case schemas.MSSQL:
if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)
whereWriter = builder.NewWriter()
fmt.Fprint(whereWriter, "WHERE ")
if err := cond.WriteTo(whereWriter); err != nil {
return 0, err
}
} else {
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
tableAlias := session.engine.Quote(tableName)
var fromSQL string
if session.statement.TableAlias != "" {
switch session.engine.dialect.URI().DBType {
case schemas.MSSQL:
fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias)
tableAlias = session.statement.TableAlias
default:
tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias)
}
}
updateWriter := builder.NewWriter() updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil {
top,
tableAlias,
strings.Join(colNames, ", "),
fromSQL); err != nil {
return 0, err
}
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
return 0, err return 0, err
} }
tableName := session.statement.TableName() // table name must been get before exec because statement will be reset
useCache := session.statement.UseCache
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
if err != nil { if err != nil {
return 0, err return 0, err
@ -436,8 +243,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
} }
if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { if cacher := session.engine.GetCacher(tableName); cacher != nil && useCache {
// session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debugf("[cache] clear table: %v", tableName) session.engine.logger.Debugf("[cache] clear table: %v", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
cacher.ClearBeans(tableName) cacher.ClearBeans(tableName)

View File

@ -116,7 +116,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{})
} }
// this will modify an old table // this will modify an old table
if err = engine.loadTableInfo(oriTable); err != nil { if err = engine.loadTableInfo(session.ctx, oriTable); err != nil {
return nil, err return nil, err
} }