diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 91a33319..187b94a3 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -43,8 +43,8 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - var hasInsertColumns = len(colNames) > 0 - var needSeq = len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) + hasInsertColumns := len(colNames) > 0 + needSeq := len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.ORACLE || statement.dialect.URI().DBType == schemas.DAMENG) if needSeq { for _, col := range colNames { if strings.EqualFold(col, table.AutoIncrement) { @@ -124,11 +124,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } - if _, err := buf.WriteString(" WHERE "); err != nil { - return "", nil, err - } - - if err := statement.Conds().WriteTo(buf); err != nil { + if err := statement.writeWhere(buf); err != nil { return "", nil, err } } else { diff --git a/internal/statements/query.go b/internal/statements/query.go index cea8be6d..2e38f0fe 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -244,6 +244,16 @@ func (statement *Statement) writeSelectColumns(w *builder.BytesWriter, columnStr } func (statement *Statement) writeWhere(w *builder.BytesWriter) error { + if !statement.cond.IsValid() { + return nil + } + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + return statement.cond.WriteTo(statement.QuoteReplacer(w)) +} + +func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { if !statement.cond.IsValid() { return statement.writeMssqlPaginationCond(w) } @@ -307,13 +317,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if err := statement.writeFrom(subWriter); err != nil { return err } - if statement.cond.IsValid() { - if _, err := fmt.Fprint(subWriter, " WHERE "); err != nil { - return err - } - if err := statement.cond.WriteTo(statement.QuoteReplacer(subWriter)); err != nil { - return err - } + if err := statement.writeWhere(subWriter); err != nil { + return err } if err := statement.WriteOrderBy(subWriter); err != nil { return err @@ -361,7 +366,7 @@ func (statement *Statement) writeSelect(buf *builder.BytesWriter, columnStr stri if err := statement.writeFrom(buf); err != nil { return err } - if err := statement.writeWhere(buf); err != nil { + if err := statement.writeWhereWithMssqlPagination(buf); err != nil { return err } if err := statement.writeGroupBy(buf); err != nil { @@ -427,13 +432,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if err := statement.writeJoins(buf); err != nil { return "", nil, err } - if statement.Conds().IsValid() { - if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { - return "", nil, err - } - if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { - return "", nil, err - } + if err := statement.writeWhere(buf); err != nil { + return "", nil, err } } else if statement.dialect.URI().DBType == schemas.ORACLE { if _, err := fmt.Fprintf(buf, "SELECT * FROM %s", tableName); err != nil { @@ -463,13 +463,8 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac if err := statement.writeJoins(buf); err != nil { return "", nil, err } - if statement.Conds().IsValid() { - if _, err := fmt.Fprintf(buf, " WHERE "); err != nil { - return "", nil, err - } - if err := statement.Conds().WriteTo(statement.QuoteReplacer(buf)); err != nil { - return "", nil, err - } + if err := statement.writeWhere(buf); err != nil { + return "", nil, err } if _, err := fmt.Fprintf(buf, " LIMIT 1"); err != nil { return "", nil, err diff --git a/internal/statements/update.go b/internal/statements/update.go index 4dc54780..16ab5676 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -9,8 +9,10 @@ import ( "errors" "fmt" "reflect" + "strings" "time" + "xorm.io/builder" "xorm.io/xorm/convert" "xorm.io/xorm/dialects" "xorm.io/xorm/internal/json" @@ -308,3 +310,85 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value, return colNames, args, nil } + +func (statement *Statement) WriteUpdate(updateWriter *builder.BytesWriter, cond builder.Cond, colNames []string) error { + whereWriter := builder.NewWriter() + if cond.IsValid() { + fmt.Fprint(whereWriter, "WHERE ") + } + if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + return err + } + if err := statement.WriteOrderBy(whereWriter); err != nil { + return err + } + + table := statement.RefTable + tableName := statement.TableName() + // TODO: Oracle support needed + var top string + if statement.LimitN != nil { + limitValue := *statement.LimitN + switch statement.dialect.URI().DBType { + case schemas.MYSQL: + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + case schemas.SQLITE: + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + + cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", + statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + return err + } + case schemas.POSTGRES: + fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) + + cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", + statement.quote(tableName), whereWriter.String()), whereWriter.Args()...)) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(statement.QuoteReplacer(whereWriter)); err != nil { + return err + } + case schemas.MSSQL: + if statement.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { + cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", + table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], + statement.quote(tableName), whereWriter.String()), whereWriter.Args()...) + + whereWriter = builder.NewWriter() + fmt.Fprint(whereWriter, "WHERE ") + if err := cond.WriteTo(whereWriter); err != nil { + return err + } + } else { + top = fmt.Sprintf("TOP (%d) ", limitValue) + } + } + } + + tableAlias := statement.quote(tableName) + var fromSQL string + if statement.TableAlias != "" { + switch statement.dialect.URI().DBType { + case schemas.MSSQL: + fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, statement.TableAlias) + tableAlias = statement.TableAlias + default: + tableAlias = fmt.Sprintf("%s AS %s", tableAlias, statement.TableAlias) + } + } + + if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", + top, + tableAlias, + strings.Join(colNames, ", "), + fromSQL); err != nil { + return err + } + return utils.WriteBuilder(updateWriter, whereWriter) +} diff --git a/session_update.go b/session_update.go index 1f80e70f..2b52a817 100644 --- a/session_update.go +++ b/session_update.go @@ -6,7 +6,6 @@ package xorm import ( "errors" - "fmt" "reflect" "strconv" "strings" @@ -142,6 +141,43 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri return nil } +func (session *Session) genAutoCond(condiBean interface{}) (builder.Cond, error) { + if session.statement.NoAutoCondition { + return builder.NewCond(), nil + } + + if c, ok := condiBean.(map[string]interface{}); ok { + eq := make(builder.Eq) + for k, v := range c { + eq[session.engine.Quote(k)] = v + } + autoCond := builder.Eq(eq) + + if session.statement.RefTable != nil { + if col := session.statement.RefTable.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled + return autoCond.And(session.statement.CondDeleted(col)), nil + } + } + return autoCond, nil + } + + ct := reflect.TypeOf(condiBean) + k := ct.Kind() + if k == reflect.Ptr { + k = ct.Elem().Kind() + } + if k == reflect.Struct { + condTable, err := session.engine.TableInfo(condiBean) + if err != nil { + return nil, err + } + + return session.statement.BuildConds(condTable, condiBean, true, true, false, true, false) + } + + return nil, ErrConditionType +} + // Update records, bean's non-empty fields are updated contents, // condiBean' non-empty filds are conditions // CAUTION: @@ -276,54 +312,14 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, err } - var autoCond builder.Cond - if !session.statement.NoAutoCondition { - condBeanIsStruct := false - if len(condiBean) > 0 { - if c, ok := condiBean[0].(map[string]interface{}); ok { - eq := make(builder.Eq) - for k, v := range c { - eq[session.engine.Quote(k)] = v - } - autoCond = builder.Eq(eq) - } else { - ct := reflect.TypeOf(condiBean[0]) - k := ct.Kind() - if k == reflect.Ptr { - k = ct.Elem().Kind() - } - if k == reflect.Struct { - condTable, err := session.engine.TableInfo(condiBean[0]) - if err != nil { - return 0, err - } - - autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false) - if err != nil { - return 0, err - } - condBeanIsStruct = true - } else { - return 0, ErrConditionType - } - } - } - - if !condBeanIsStruct && table != nil { - if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled - autoCond1 := session.statement.CondDeleted(col) - - if autoCond == nil { - autoCond = autoCond1 - } else { - autoCond = autoCond.And(autoCond1) - } - } + autoCond := builder.NewCond() + if len(condiBean) > 0 { + autoCond, err = session.genAutoCond(condiBean[0]) + if err != nil { + return 0, err } } - st := session.statement - var ( cond = session.statement.Conds().And(autoCond) doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.CheckVersion) @@ -345,85 +341,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 return 0, ErrNoColumnsTobeUpdated } - whereWriter := builder.NewWriter() - if cond.IsValid() { - fmt.Fprint(whereWriter, "WHERE ") - } - if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { - return 0, err - } - if err := st.WriteOrderBy(whereWriter); err != nil { - return 0, err - } - - tableName := session.statement.TableName() - // TODO: Oracle support needed - var top string - if st.LimitN != nil { - limitValue := *st.LimitN - switch session.engine.dialect.URI().DBType { - case schemas.MYSQL: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - case schemas.SQLITE: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - - cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", - session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { - return 0, err - } - case schemas.POSTGRES: - fmt.Fprintf(whereWriter, " LIMIT %d", limitValue) - - cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", - session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...)) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(st.QuoteReplacer(whereWriter)); err != nil { - return 0, err - } - case schemas.MSSQL: - if st.HasOrderBy() && table != nil && len(table.PrimaryKeys) == 1 { - cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)", - table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0], - session.engine.Quote(tableName), whereWriter.String()), whereWriter.Args()...) - - whereWriter = builder.NewWriter() - fmt.Fprint(whereWriter, "WHERE ") - if err := cond.WriteTo(whereWriter); err != nil { - return 0, err - } - } else { - top = fmt.Sprintf("TOP (%d) ", limitValue) - } - } - } - - tableAlias := session.engine.Quote(tableName) - var fromSQL string - if session.statement.TableAlias != "" { - switch session.engine.dialect.URI().DBType { - case schemas.MSSQL: - fromSQL = fmt.Sprintf("FROM %s %s ", tableAlias, session.statement.TableAlias) - tableAlias = session.statement.TableAlias - default: - tableAlias = fmt.Sprintf("%s AS %s", tableAlias, session.statement.TableAlias) - } - } - updateWriter := builder.NewWriter() - if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", - top, - tableAlias, - strings.Join(colNames, ", "), - fromSQL); err != nil { - return 0, err - } - if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil { + if err := session.statement.WriteUpdate(updateWriter, cond, colNames); err != nil { return 0, err } @@ -436,6 +355,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } + tableName := session.statement.TableName() if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { // session.cacheUpdate(table, tableName, sqlStr, args...) session.engine.logger.Debugf("[cache] clear table: %v", tableName)