From 9a7b4e7af526d67e8dbbdd8a8efb07f437d0aa9c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 4 Mar 2020 14:59:03 +0800 Subject: [PATCH] Move some codes to statement sub package --- engine.go | 4 - error.go | 4 - internal/statements/delete.go | 139 ++++++++++++++++++++++++++++++++++ internal/statements/query.go | 50 +++++------- session_convert.go | 3 +- session_delete.go | 122 +++++------------------------ 6 files changed, 177 insertions(+), 145 deletions(-) create mode 100644 internal/statements/delete.go diff --git a/engine.go b/engine.go index 221b7488..8b4f3931 100644 --- a/engine.go +++ b/engine.go @@ -1211,10 +1211,6 @@ func (engine *Engine) nowTime(col *schemas.Column) (interface{}, time.Time) { return dialects.FormatTime(engine.dialect, col.SQLType.Name, t.In(tz)), t.In(engine.TZLocation) } -func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interface{}) { - return dialects.FormatColumnTime(engine.dialect, engine.DatabaseTZ, col, t) -} - // GetColumnMapper returns the column name mapper func (engine *Engine) GetColumnMapper() names.Mapper { return engine.tagParser.GetColumnMapper() diff --git a/error.go b/error.go index a19860e3..21a83f47 100644 --- a/error.go +++ b/error.go @@ -20,10 +20,6 @@ var ( ErrNotExist = errors.New("Record does not exist") // ErrCacheFailed cache failed error ErrCacheFailed = errors.New("Cache failed") - // ErrNeedDeletedCond delete needs less one condition error - ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") - // ErrNotImplemented not implemented - ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") ) diff --git a/internal/statements/delete.go b/internal/statements/delete.go new file mode 100644 index 00000000..de4f9f0f --- /dev/null +++ b/internal/statements/delete.go @@ -0,0 +1,139 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package statements + +import ( + "errors" + "fmt" + "time" + + "xorm.io/xorm/dialects" + "xorm.io/xorm/schemas" +) + +var ( + // ErrNeedDeletedCond delete needs less one condition error + ErrNeedDeletedCond = errors.New("Delete action needs at least one condition") + + // ErrNotImplemented not implemented + ErrNotImplemented = errors.New("Not implemented") +) + +// GenDeleteSQL generated delete SQL according conditions +func (statement *Statement) GenDeleteSQL(bean interface{}) (string, string, []interface{}, error) { + condSQL, condArgs, err := statement.GenConds(bean) + if err != nil { + return "", "", nil, err + } + pLimitN := statement.LimitN + if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { + return "", "", nil, ErrNeedDeletedCond + } + + var tableNameNoQuote = statement.TableName() + var tableName = statement.quote(tableNameNoQuote) + var table = statement.RefTable + var deleteSQL string + if len(condSQL) > 0 { + deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) + } else { + deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) + } + + var orderSQL string + if len(statement.OrderStr) > 0 { + orderSQL += fmt.Sprintf(" ORDER BY %s", statement.OrderStr) + } + if pLimitN != nil && *pLimitN > 0 { + limitNValue := *pLimitN + orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) + } + + if len(orderSQL) > 0 { + switch statement.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + deleteSQL += " AND " + inSQL + } else { + deleteSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + deleteSQL += " AND " + inSQL + } else { + deleteSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return "", "", nil, ErrNotImplemented + default: + deleteSQL += orderSQL + } + } + + var realSQL string + argsForCache := make([]interface{}, 0, len(condArgs)*2) + if statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled + realSQL = deleteSQL + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + } else { + // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + + deletedColumn := table.DeletedColumn() + realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + statement.quote(statement.TableName()), + statement.quote(deletedColumn.Name), + condSQL) + + if len(orderSQL) > 0 { + switch statement.dialect.DBType() { + case schemas.POSTGRES: + inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + case schemas.SQLITE: + inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) + if len(condSQL) > 0 { + realSQL += " AND " + inSQL + } else { + realSQL += " WHERE " + inSQL + } + // TODO: how to handle delete limit on mssql? + case schemas.MSSQL: + return "", "", nil, ErrNotImplemented + default: + realSQL += orderSQL + } + } + + // !oinume! Insert nowTime to the head of statement.Params + condArgs = append(condArgs, "") + paramsLen := len(condArgs) + copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) + + now := ColumnNow(deletedColumn, statement.defaultTimeZone) + val := dialects.FormatTime(statement.dialect, deletedColumn.SQLType.Name, now) + condArgs[0] = val + } + return realSQL, deleteSQL, condArgs, nil +} + +// ColumnNow returns the current time for a column +func ColumnNow(col *schemas.Column, defaultTimeZone *time.Location) time.Time { + t := time.Now() + tz := defaultTimeZone + if !col.DisableTimeZone && col.TimeZone != nil { + tz = col.TimeZone + } + return t.In(tz) +} diff --git a/internal/statements/query.go b/internal/statements/query.go index 1519cb08..a058f752 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -57,16 +57,12 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, err } - condSQL, condArgs, err := builder.ToSQL(statement.cond) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } - args := append(statement.joinArgs, condArgs...) - sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) - if err != nil { - return "", nil, err - } + // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { @@ -92,12 +88,11 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - condSQL, condArgs, err := statement.GenConds(bean) - if err != nil { + if err := statement.mergeConds(bean); err != nil { return "", nil, err } - sqlStr, err := statement.GenSelectSQL(sumSelect, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) if err != nil { return "", nil, err } @@ -147,12 +142,8 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, return "", nil, err } } - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - sqlStr, err := statement.GenSelectSQL(columnStr, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } @@ -165,17 +156,13 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return statement.RawSQL, statement.RawParams, nil } - var condSQL string var condArgs []interface{} var err error if len(beans) > 0 { statement.SetRefBean(beans[0]) - condSQL, condArgs, err = statement.GenConds(beans[0]) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) - } - if err != nil { - return "", nil, err + if err := statement.mergeConds(beans[0]); err != nil { + return "", nil, err + } } var selectSQL = statement.SelectStr @@ -186,7 +173,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, err := statement.GenSelectSQL(selectSQL, condSQL, false, false) + sqlStr, condArgs, err := statement.genSelectSQL(selectSQL, false, false) if err != nil { return "", nil, err } @@ -194,7 +181,7 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { +func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { var ( distinct string dialect = statement.dialect @@ -205,6 +192,11 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " } + + condSQL, condArgs, err := builder.ToSQL(statement.cond) + if err != nil { + return "", nil, err + } if len(condSQL) > 0 { whereStr = " WHERE " + condSQL } @@ -313,10 +305,10 @@ func (statement *Statement) GenSelectSQL(columnStr, condSQL string, needLimit, n } } if statement.IsForUpdate { - return dialect.ForUpdateSQL(buf.String()), nil + return dialect.ForUpdateSQL(buf.String()), condArgs, nil } - return buf.String(), nil + return buf.String(), condArgs, nil } func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { @@ -428,16 +420,12 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa } statement.cond = statement.cond.And(autoCond) - condSQL, condArgs, err := builder.ToSQL(statement.cond) - if err != nil { - return "", nil, err - } - args = append(statement.joinArgs, condArgs...) - sqlStr, err = statement.GenSelectSQL(columnStr, condSQL, true, true) + sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } + args = append(statement.joinArgs, condArgs...) // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { diff --git a/session_convert.go b/session_convert.go index 1cd00627..0776bc45 100644 --- a/session_convert.go +++ b/session_convert.go @@ -15,6 +15,7 @@ import ( "time" "xorm.io/xorm/convert" + "xorm.io/xorm/dialects" "xorm.io/xorm/internal/json" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" @@ -583,7 +584,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. case reflect.Struct: if fieldType.ConvertibleTo(schemas.TimeType) { t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - tf := session.engine.formatColTime(col, t) + tf := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, t) return tf, nil } else if fieldType.ConvertibleTo(nullFloatType) { t := fieldValue.Convert(nullFloatType).Interface().(sql.NullFloat64) diff --git a/session_delete.go b/session_delete.go index 04200035..3373d89e 100644 --- a/session_delete.go +++ b/session_delete.go @@ -6,8 +6,8 @@ package xorm import ( "errors" - "fmt" "strconv" + "time" "xorm.io/xorm/caches" "xorm.io/xorm/schemas" @@ -98,119 +98,31 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - condSQL, condArgs, err := session.statement.GenConds(bean) + realSQL, deleteSQL, condArgs, err := session.statement.GenDeleteSQL(bean) if err != nil { return 0, err } - pLimitN := session.statement.LimitN - if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) { - return 0, ErrNeedDeletedCond + + argsForCache := make([]interface{}, 0, len(condArgs)*2) + copy(argsForCache, condArgs) + argsForCache = append(condArgs, argsForCache...) + + if !session.statement.GetUnscoped() && session.statement.RefTable.DeletedColumn() != nil { + deletedColumn := session.statement.RefTable.DeletedColumn() + + session.afterClosures = append(session.afterClosures, func(col *schemas.Column, tz *time.Location) func(interface{}) { + return func(bean interface{}) { + t := time.Now().In(tz) + setColumnTime(bean, col, t) + } + }(deletedColumn, session.engine.TZLocation)) } var tableNameNoQuote = session.statement.TableName() - var tableName = session.engine.Quote(tableNameNoQuote) - var table = session.statement.RefTable - var deleteSQL string - if len(condSQL) > 0 { - deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL) - } else { - deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName) - } - - var orderSQL string - if len(session.statement.OrderStr) > 0 { - orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr) - } - if pLimitN != nil && *pLimitN > 0 { - limitNValue := *pLimitN - orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue) - } - - if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL - } else { - deleteSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - deleteSQL += " AND " + inSQL - } else { - deleteSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return 0, ErrNotImplemented - default: - deleteSQL += orderSQL - } - } - - var realSQL string - argsForCache := make([]interface{}, 0, len(condArgs)*2) - if session.statement.GetUnscoped() || table.DeletedColumn() == nil { // tag "deleted" is disabled - realSQL = deleteSQL - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - } else { - // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches. - copy(argsForCache, condArgs) - argsForCache = append(condArgs, argsForCache...) - - deletedColumn := table.DeletedColumn() - realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", - session.engine.Quote(session.statement.TableName()), - session.engine.Quote(deletedColumn.Name), - condSQL) - - if len(orderSQL) > 0 { - switch session.engine.dialect.DBType() { - case schemas.POSTGRES: - inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - case schemas.SQLITE: - inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL) - if len(condSQL) > 0 { - realSQL += " AND " + inSQL - } else { - realSQL += " WHERE " + inSQL - } - // TODO: how to handle delete limit on mssql? - case schemas.MSSQL: - return 0, ErrNotImplemented - default: - realSQL += orderSQL - } - } - - // !oinume! Insert nowTime to the head of session.statement.Params - condArgs = append(condArgs, "") - paramsLen := len(condArgs) - copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1]) - - val, t := session.engine.nowTime(deletedColumn) - condArgs[0] = val - - var colName = deletedColumn.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) - } - if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache { - session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...) + session.cacheDelete(session.statement.RefTable, tableNameNoQuote, deleteSQL, argsForCache...) } - session.statement.RefTable = table res, err := session.exec(realSQL, condArgs...) if err != nil { return 0, err