diff --git a/engine.go b/engine.go index aa9d8050..0cbfdede 100644 --- a/engine.go +++ b/engine.go @@ -360,15 +360,15 @@ func (engine *Engine) NoAutoCondition(no ...bool) *Session { return session.NoAutoCondition(no...) } -func (engine *Engine) loadTableInfo(table *schemas.Table) error { - colSeq, cols, err := engine.dialect.GetColumns(engine.db, engine.defaultContext, table.Name) +func (engine *Engine) loadTableInfo(ctx context.Context, table *schemas.Table) error { + colSeq, cols, err := engine.dialect.GetColumns(engine.db, ctx, table.Name) if err != nil { return err } for _, name := range colSeq { table.AddColumn(cols[name]) } - indexes, err := engine.dialect.GetIndexes(engine.db, engine.defaultContext, table.Name) + indexes, err := engine.dialect.GetIndexes(engine.db, ctx, table.Name) if err != nil { return err } @@ -404,7 +404,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) { } for _, table := range tables { - if err = engine.loadTableInfo(table); err != nil { + if err = engine.loadTableInfo(engine.defaultContext, table); err != nil { return nil, err } } diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 91a33319..187b94a3 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -43,8 +43,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - var hasInsertColumns = len(colNames) > 0 - var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + hasInsertColumns := len(colNames) > 0 + needSeq := len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) if needSeq { for _, col := range colNames { if strings.EqualFold(col, table.AutoIncrement) { @@ -124,11 +124,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if _, err := buf.WriteString(" WHERE "); err != nil { - return "", nil, err - } - - if err := statement.Conds().WriteTo(buf); err != nil { + if err := statement.writeWhere(buf); err != nil { return "", nil, err } } else { diff --git a/internal/statements/query.go b/internal/statements/query.go index cea8be6d..2e38f0fe 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -244,6 +244,16 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr } func (statement *Statement) writeWhere(w *builder.BytesWriter) error { + if !statement.cond.IsValid() { + return nil + } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + return statement.cond.WriteTo(statement.QuoteReplacer(w)) +} + +func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { if !statement.cond.IsValid() { return statement.writeMssqlPaginationCond(w) } @@ -307,13 +317,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if err := statement.writeFrom(subWriter); err != nil { return err } - if statement.cond.IsValid() { - if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil { - return err - } - if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil { - return err - } + if err := statement.writeWhere(subWriter); err != nil { + return err } if err := statement.WriteOrderBy(subWriter); err != nil { return err @@ -361,7 +366,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri if err := statement.writeFrom(buf); err != nil { return err } - if err := statement.writeWhere(buf); err != nil { + if err := statement.writeWhereWithMssqlPagination(buf); err != nil { return err } if err := statement.writeGroupBy(buf); err != nil { @@ -427,13 +432,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if err := statement.writeJoins(buf); err != nil { return "", nil, err } - if statement.Conds().IsValid() { - if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { - return "", nil, err - } - if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { - return "", nil, err - } + if err := statement.writeWhere(buf); err != nil { + return "", nil, err } } else if statement.dialect.URI().DBType == schemas.ORACLE { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { @@ -463,13 +463,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if err := statement.writeJoins(buf); err != nil { return "", nil, err } - if statement.Conds().IsValid() { - if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { - return "", nil, err - } - if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { - return "", nil, err - } + if err := statement.writeWhere(buf); err != nil { + return "", nil, err } if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil { return "", nil, err diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 306676e5..61488ff7 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -299,13 +299,13 @@ func (statement *Statement) writeGroupBy(w builder.Writer) error { if statement.GroupByStr == "" { return nil } - _, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr) + _, err := fmt.Fprint(w, " GROUP BY ", statement.GroupByStr) return err } // Having generate "Having conditions" statement func (statement *Statement) Having(conditions string) *Statement { - statement.HavingStr = fmt.Sprintf("HAVING %v", statement.ReplaceQuote(conditions)) + statement.HavingStr = conditions return statement } @@ -313,7 +313,7 @@ func (statement *Statement) writeHaving(w builder.Writer) error { if statement.HavingStr == "" { return nil } - _, err := fmt.Fprint(w, " ", statement.HavingStr) + _, err := fmt.Fprint(w, " HAVING ", statement.ReplaceQuote(statement.HavingStr)) return err } diff --git a/internal/statements/update.go b/internal/statements/update.go index 4dc54780..16ab5676 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -9,8 +9,10 @@ import ( "errors" "fmt" "reflect" + "strings" "time" + "xorm.io/builder" "xorm.io/xorm/convert" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/json" @@ -308,3 +310,85 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, return colNames, args, nil } + +func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error { + whereWriter := builder.NewWriter() + if cond.IsValid() { + fmt.Fprint(whereWriter, "WHERE ") + } + if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + return err + } + if err := statement.WriteOrderBy(whereWriter); err != nil { + return err + } + + table := statement.RefTable + tableName := statement.TableName() + // TODO: Oracle support needed + var top string + if statement.LimitN != nil { + limitValue := *statement.LimitN + switch statement.dialect.URI().DBType { + case schemas.MYSQL: + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + case schemas.SQLITE: + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + + cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", + statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + return err + } + case schemas.POSTGRES: + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + + cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", + statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + return err + } + case schemas.MSSQL: + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], + statement.quote(tableName), whereWriter.String()), whereWriter.Args()...) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(whereWriter); err != nil { + return err + } + } else { + top = fmt.Sprintf("TOP (%d) ", limitValue) + } + } + } + + tableAlias := statement.quote(tableName) + var fromSQL string + if statement.TableAlias != "" { + switch statement.dialect.URI().DBType { + case schemas.MSSQL: + fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias) + tableAlias = statement.TableAlias + default: + tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias) + } + } + + if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", + top, + tableAlias, + strings.Join(colNames, ", "), + fromSQL); err != nil { + return err + } + return utils.WriteBuilder(updateWriter, whereWriter) +} diff --git a/session_stats.go b/session_stats.go index 5d0da5e9..be98e467 100644 --- a/session_stats.go +++ b/session_stats.go @@ -24,7 +24,7 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { var total int64 err = session.queryRow(sqlStr, args...).Scan(&total) - if err == sql.ErrNoRows || err == nil { + if err == nil { return total, nil } @@ -70,12 +70,12 @@ func (session *Session) SumInt(bean interface{}, columnName string) (res int64, // Sums call sum some columns. bean's non-empty fields are conditions. func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64, error) { - var res = make([]float64, len(columnNames)) + res := make([]float64, len(columnNames)) return res, session.sum(&res, bean, columnNames...) } // SumsInt sum specify columns and return as []int64 instead of []float64 func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int64, error) { - var res = make([]int64, len(columnNames)) + res := make([]int64, len(columnNames)) return res, session.sum(&res, bean, columnNames...) } diff --git a/session_update.go b/session_update.go index 1f80e70f..9a6964f1 100644 --- a/session_update.go +++ b/session_update.go @@ -6,13 +6,9 @@ package xorm import ( "errors" - "fmt" "reflect" - "strconv" - "strings" "xorm.io/builder" - "xorm.io/xorm/caches" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) @@ -22,124 +18,39 @@ var ( ErrNoColumnsTobeUpdated = errors.New("no columns found to be updated") ) -//revive:disable -func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr string, args ...interface{}) error { - if table == nil || - session.tx != nil { - return ErrCacheFailed +func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { + if session.statement.NoAutoCondition { + return builder.NewCond(), nil } - oldhead, newsql := session.statement.ConvertUpdateSQL(sqlStr) - if newsql == "" { - return ErrCacheFailed - } - for _, filter := range session.engine.dialect.Filters() { - newsql = filter.Do(session.ctx, newsql) - } - session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) - - var nStart int - if len(args) > 0 { - if strings.Contains(sqlStr, "?") { - nStart = strings.Count(oldhead, "?") - } else { - // only for pq, TODO: if any other databse? - nStart = strings.Count(oldhead, "$") + if c, ok := condiBean.(map[string]interface{}); ok { + eq := make(builder.Eq) + for k, v := range c { + eq[session.engine.Quote(k)] = v } + + if session.statement.RefTable != nil { + if col := session.statement.RefTable.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + return eq.And(session.statement.CondDeleted(col)), nil + } + } + return eq, nil } - cacher := session.engine.GetCacher(tableName) - session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:]) - ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:]) + ct := reflect.TypeOf(condiBean) + k := ct.Kind() + if k == reflect.Ptr { + k = ct.Elem().Kind() + } + if k != reflect.Struct { + return nil, ErrConditionType + } + + condTable, err := session.engine.TableInfo(condiBean) if err != nil { - rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) - if err != nil { - return err - } - defer rows.Close() - - ids = make([]schemas.PK, 0) - for rows.Next() { - res := make([]string, len(table.PrimaryKeys)) - err = rows.ScanSlice(&res) - if err != nil { - return err - } - var pk schemas.PK = make([]interface{}, len(table.PrimaryKeys)) - for i, col := range table.PKColumns() { - if col.SQLType.IsNumeric() { - n, err := strconv.ParseInt(res[i], 10, 64) - if err != nil { - return err - } - pk[i] = n - } else if col.SQLType.IsText() { - pk[i] = res[i] - } else { - return errors.New("not supported") - } - } - - ids = append(ids, pk) - } - if rows.Err() != nil { - return rows.Err() - } - session.engine.logger.Debugf("[cache] find updated id: %v", ids) - } /*else { - session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) - cacher.DelIds(tableName, genSqlKey(newsql, args)) - }*/ - - for _, id := range ids { - sid, err := id.ToString() - if err != nil { - return err - } - if bean := cacher.GetBean(tableName, sid); bean != nil { - sqls := utils.SplitNNoCase(sqlStr, "where", 2) - if len(sqls) == 0 || len(sqls) > 2 { - return ErrCacheFailed - } - - sqls = utils.SplitNNoCase(sqls[0], "set", 2) - if len(sqls) != 2 { - return ErrCacheFailed - } - kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") - - for idx, kv := range kvs { - sps := strings.SplitN(kv, "=", 2) - sps2 := strings.Split(sps[0], ".") - colName := sps2[len(sps2)-1] - colName = session.engine.dialect.Quoter().Trim(colName) - colName = schemas.CommonQuoter.Trim(colName) - - if col := table.GetColumn(colName); col != nil { - fieldValue, err := col.ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } else { - session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface()) - if col.IsVersion && session.statement.CheckVersion { - session.incrVersionFieldValue(fieldValue) - } else { - fieldValue.Set(reflect.ValueOf(args[idx])) - } - } - } else { - session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's", - colName, table.Name) - } - } - - session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean) - cacher.PutBean(tableName, sid, bean) - } + return nil, err } - session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName) - cacher.ClearIds(tableName) - return nil + return session.statement.BuildConds(condTable, condiBean, true, true, false, true, false) } // Update records, bean's non-empty fields are updated contents, @@ -277,53 +188,23 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } var autoCond builder.Cond - if !session.statement.NoAutoCondition { - condBeanIsStruct := false - if len(condiBean) > 0 { - if c, ok := condiBean[0].(map[string]interface{}); ok { - eq := make(builder.Eq) - for k, v := range c { - eq[session.engine.Quote(k)] = v - } - autoCond = builder.Eq(eq) - } else { - ct := reflect.TypeOf(condiBean[0]) - k := ct.Kind() - if k == reflect.Ptr { - k = ct.Elem().Kind() - } - if k == reflect.Struct { - condTable, err := session.engine.TableInfo(condiBean[0]) - if err != nil { - return 0, err - } - - autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) - if err != nil { - return 0, err - } - condBeanIsStruct = true - } else { - return 0, ErrConditionType - } - } + if len(condiBean) > 0 { + autoCond, err = session.genAutoCond(condiBean[0]) + if err != nil { + return 0, err } + } else if table != nil { + if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + autoCond1 := session.statement.CondDeleted(col) - if !condBeanIsStruct && table != nil { - if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled - autoCond1 := session.statement.CondDeleted(col) - - if autoCond == nil { - autoCond = autoCond1 - } else { - autoCond = autoCond.And(autoCond1) - } + if autoCond == nil { + autoCond = autoCond1 + } else { + autoCond = autoCond.And(autoCond1) } } } - st := session.statement - var ( cond = session.statement.Conds().And(autoCond) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) @@ -345,88 +226,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrNoColumnsTobeUpdated } - whereWriter := builder.NewWriter() - if cond.IsValid() { - fmt.Fprint(whereWriter, "WHERE ") - } - if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { - return 0, err - } - if err := st.WriteOrderBy(whereWriter); err != nil { - return 0, err - } - - tableName := session.statement.TableName() - // TODO: Oracle support needed - var top string - if st.LimitN != nil { - limitValue := *st.LimitN - switch session.engine.dialect.URI().DBType { - case schemas.MYSQL: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - case schemas.SQLITE: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - - cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { - return 0, err - } - case schemas.POSTGRES: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - - cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { - return 0, err - } - case schemas.MSSQL: - if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { - cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], - session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(whereWriter); err != nil { - return 0, err - } - } else { - top = fmt.Sprintf("TOP (%d) ", limitValue) - } - } - } - - tableAlias := session.engine.Quote(tableName) - var fromSQL string - if session.statement.TableAlias != "" { - switch session.engine.dialect.URI().DBType { - case schemas.MSSQL: - fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) - tableAlias = session.statement.TableAlias - default: - tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias) - } - } - updateWriter := builder.NewWriter() - if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", - top, - tableAlias, - strings.Join(colNames, ", "), - fromSQL); err != nil { - return 0, err - } - if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil { + if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil { return 0, err } + tableName := session.statement.TableName() // table name must been get before exec because statement will be reset + useCache := session.statement.UseCache + res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) if err != nil { return 0, err @@ -436,8 +243,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { - // session.cacheUpdate(table, tableName, sqlStr, args...) + if cacher := session.engine.GetCacher(tableName); cacher != nil && useCache { session.engine.logger.Debugf("[cache] clear table: %v", tableName) cacher.ClearIds(tableName) cacher.ClearBeans(tableName) diff --git a/sync.go b/sync.go index 11e75404..635a8ba9 100644 --- a/sync.go +++ b/sync.go @@ -116,7 +116,7 @@ func (session *Session) SyncWithOptions(opts SyncOptions, beans ...interface{}) } // this will modify an old table - if err = engine.loadTableInfo(oriTable); err != nil { + if err = engine.loadTableInfo(session.ctx, oriTable); err != nil { return nil, err }