Refactor orderby and support arguments
This commit is contained in:
parent
eeb7fcf22c
commit
ec7b41fd88
12
engine.go
12
engine.go
|
@ -380,7 +380,7 @@ func (engine *Engine) loadTableInfo(table *schemas.Table) error {
|
||||||
seq = 0
|
seq = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var colName = strings.Trim(parts[0], `"`)
|
colName := strings.Trim(parts[0], `"`)
|
||||||
if col := table.GetColumn(colName); col != nil {
|
if col := table.GetColumn(colName); col != nil {
|
||||||
col.Indexes[index.Name] = index.Type
|
col.Indexes[index.Name] = index.Type
|
||||||
} else {
|
} else {
|
||||||
|
@ -502,9 +502,9 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var dstTableName = dstTable.Name
|
dstTableName := dstTable.Name
|
||||||
var quoter = dstDialect.Quoter().Quote
|
quoter := dstDialect.Quoter().Quote
|
||||||
var quotedDstTableName = quoter(dstTable.Name)
|
quotedDstTableName := quoter(dstTable.Name)
|
||||||
if dstDialect.URI().Schema != "" {
|
if dstDialect.URI().Schema != "" {
|
||||||
dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
|
dstTableName = fmt.Sprintf("%s.%s", dstDialect.URI().Schema, dstTable.Name)
|
||||||
quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name))
|
quotedDstTableName = fmt.Sprintf("%s.%s", quoter(dstDialect.URI().Schema), quoter(dstTable.Name))
|
||||||
|
@ -1006,10 +1006,10 @@ func (engine *Engine) Asc(colNames ...string) *Session {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OrderBy will generate "ORDER BY order"
|
// OrderBy will generate "ORDER BY order"
|
||||||
func (engine *Engine) OrderBy(order string) *Session {
|
func (engine *Engine) OrderBy(order string, args ...interface{}) *Session {
|
||||||
session := engine.NewSession()
|
session := engine.NewSession()
|
||||||
session.isAutoClose = true
|
session.isAutoClose = true
|
||||||
return session.OrderBy(order)
|
return session.OrderBy(order, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare enables prepare statement
|
// Prepare enables prepare statement
|
||||||
|
|
|
@ -247,6 +247,10 @@ func TestOrder(t *testing.T) {
|
||||||
users2 := make([]Userinfo, 0)
|
users2 := make([]Userinfo, 0)
|
||||||
err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
|
err = testEngine.Asc("id", "username").Desc("height").Find(&users2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
users = make([]Userinfo, 0)
|
||||||
|
err = testEngine.OrderBy("case username like ? desc", "a").Find(&users)
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGroupBy(t *testing.T) {
|
func TestGroupBy(t *testing.T) {
|
||||||
|
|
|
@ -54,7 +54,7 @@ type Interface interface {
|
||||||
Nullable(...string) *Session
|
Nullable(...string) *Session
|
||||||
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
|
Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session
|
||||||
Omit(columns ...string) *Session
|
Omit(columns ...string) *Session
|
||||||
OrderBy(order string) *Session
|
OrderBy(order string, args ...interface{}) *Session
|
||||||
Ping() error
|
Ping() error
|
||||||
Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error)
|
Query(sqlOrArgs ...interface{}) (resultsSlice []map[string][]byte, err error)
|
||||||
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)
|
QueryInterface(sqlOrArgs ...interface{}) ([]map[string]interface{}, error)
|
||||||
|
|
|
@ -28,7 +28,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
|
||||||
return "", nil, ErrTableNotFound
|
return "", nil, ErrTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
var columnStr = statement.ColumnStr()
|
columnStr := statement.ColumnStr()
|
||||||
if len(statement.SelectStr) > 0 {
|
if len(statement.SelectStr) > 0 {
|
||||||
columnStr = statement.SelectStr
|
columnStr = statement.SelectStr
|
||||||
} else {
|
} else {
|
||||||
|
@ -83,7 +83,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var sumStrs = make([]string, 0, len(columns))
|
sumStrs := make([]string, 0, len(columns))
|
||||||
for _, colName := range columns {
|
for _, colName := range columns {
|
||||||
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
|
||||||
colName = statement.quote(colName)
|
colName = statement.quote(colName)
|
||||||
|
@ -94,7 +94,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
|
||||||
}
|
}
|
||||||
sumSelect := strings.Join(sumStrs, ", ")
|
sumSelect := strings.Join(sumStrs, ", ")
|
||||||
|
|
||||||
if err := statement.mergeConds(bean); err != nil {
|
if err := statement.MergeConds(bean); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var columnStr = statement.ColumnStr()
|
columnStr := statement.ColumnStr()
|
||||||
if len(statement.SelectStr) > 0 {
|
if len(statement.SelectStr) > 0 {
|
||||||
columnStr = statement.SelectStr
|
columnStr = statement.SelectStr
|
||||||
} else {
|
} else {
|
||||||
|
@ -146,7 +146,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
|
||||||
}
|
}
|
||||||
|
|
||||||
if isStruct {
|
if isStruct {
|
||||||
if err := statement.mergeConds(bean); err != nil {
|
if err := statement.MergeConds(bean); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -175,12 +175,12 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
if err := statement.SetRefBean(beans[0]); err != nil {
|
if err := statement.SetRefBean(beans[0]); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if err := statement.mergeConds(beans[0]); err != nil {
|
if err := statement.MergeConds(beans[0]); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var selectSQL = statement.SelectStr
|
selectSQL := statement.SelectStr
|
||||||
if len(selectSQL) <= 0 {
|
if len(selectSQL) <= 0 {
|
||||||
if statement.IsDistinct {
|
if statement.IsDistinct {
|
||||||
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
|
selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr())
|
||||||
|
@ -211,8 +211,8 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
|
||||||
|
|
||||||
func (statement *Statement) fromBuilder() *strings.Builder {
|
func (statement *Statement) fromBuilder() *strings.Builder {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
var quote = statement.quote
|
quote := statement.quote
|
||||||
var dialect = statement.dialect
|
dialect := statement.dialect
|
||||||
|
|
||||||
builder.WriteString(" FROM ")
|
builder.WriteString(" FROM ")
|
||||||
|
|
||||||
|
@ -242,7 +242,8 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
|
||||||
distinct string
|
distinct string
|
||||||
dialect = statement.dialect
|
dialect = statement.dialect
|
||||||
fromStr = statement.fromBuilder().String()
|
fromStr = statement.fromBuilder().String()
|
||||||
top, mssqlCondi, whereStr string
|
top, whereStr string
|
||||||
|
mssqlCondi = builder.NewWriter()
|
||||||
)
|
)
|
||||||
|
|
||||||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
||||||
|
@ -289,49 +290,59 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var orderStr string
|
orderByWriter := builder.NewWriter()
|
||||||
if needOrderBy && len(statement.OrderStr) > 0 {
|
if needOrderBy {
|
||||||
orderStr = fmt.Sprintf(" ORDER BY %s", statement.OrderStr)
|
if err := statement.WriteOrderBy(orderByWriter); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var groupStr string
|
var groupStr string
|
||||||
if len(statement.GroupByStr) > 0 {
|
if len(statement.GroupByStr) > 0 {
|
||||||
groupStr = fmt.Sprintf(" GROUP BY %s", statement.GroupByStr)
|
groupStr = fmt.Sprintf(" GROUP BY %s", statement.GroupByStr)
|
||||||
}
|
}
|
||||||
mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
|
|
||||||
column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
|
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
|
||||||
|
column, statement.Start, column, fromStr, whereStr, orderByWriter.String(), groupStr); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
mssqlCondi.Append(orderByWriter.Args()...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf strings.Builder
|
buf := builder.NewWriter()
|
||||||
fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
|
fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
|
||||||
if len(mssqlCondi) > 0 {
|
if mssqlCondi.Len() > 0 {
|
||||||
if len(whereStr) > 0 {
|
if len(whereStr) > 0 {
|
||||||
fmt.Fprint(&buf, " AND ", mssqlCondi)
|
fmt.Fprint(buf, " AND ")
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprint(&buf, " WHERE ", mssqlCondi)
|
fmt.Fprint(buf, " WHERE ")
|
||||||
}
|
}
|
||||||
|
fmt.Fprint(buf, mssqlCondi.String())
|
||||||
|
buf.Append(mssqlCondi.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if statement.GroupByStr != "" {
|
if statement.GroupByStr != "" {
|
||||||
fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
|
fmt.Fprint(buf, " GROUP BY ", statement.GroupByStr)
|
||||||
}
|
}
|
||||||
if statement.HavingStr != "" {
|
if statement.HavingStr != "" {
|
||||||
fmt.Fprint(&buf, " ", statement.HavingStr)
|
fmt.Fprint(buf, " ", statement.HavingStr)
|
||||||
|
}
|
||||||
|
if needOrderBy {
|
||||||
|
if err := statement.WriteOrderBy(buf); err != nil {
|
||||||
|
return "", nil, err
|
||||||
}
|
}
|
||||||
if needOrderBy && statement.OrderStr != "" {
|
|
||||||
fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
|
|
||||||
}
|
}
|
||||||
if needLimit {
|
if needLimit {
|
||||||
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
|
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
|
||||||
if statement.Start > 0 {
|
if statement.Start > 0 {
|
||||||
if pLimitN != nil {
|
if pLimitN != nil {
|
||||||
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
|
fmt.Fprintf(buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(&buf, " LIMIT 0 OFFSET %v", statement.Start)
|
fmt.Fprintf(buf, " LIMIT 0 OFFSET %v", statement.Start)
|
||||||
}
|
}
|
||||||
} else if pLimitN != nil {
|
} else if pLimitN != nil {
|
||||||
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
|
fmt.Fprint(buf, " LIMIT ", *pLimitN)
|
||||||
}
|
}
|
||||||
} else if dialect.URI().DBType == schemas.ORACLE {
|
} else if dialect.URI().DBType == schemas.ORACLE {
|
||||||
if pLimitN != nil {
|
if pLimitN != nil {
|
||||||
|
@ -341,7 +352,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
|
||||||
if rawColStr == "*" {
|
if rawColStr == "*" {
|
||||||
rawColStr = "at.*"
|
rawColStr = "at.*"
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
|
fmt.Fprintf(buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
|
||||||
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
|
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -436,7 +447,7 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
|
||||||
return "", nil, ErrTableNotFound
|
return "", nil, ErrTableNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
var columnStr = statement.ColumnStr()
|
columnStr := statement.ColumnStr()
|
||||||
if len(statement.SelectStr) > 0 {
|
if len(statement.SelectStr) > 0 {
|
||||||
columnStr = statement.SelectStr
|
columnStr = statement.SelectStr
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -44,6 +44,7 @@ type Statement struct {
|
||||||
LimitN *int
|
LimitN *int
|
||||||
idParam schemas.PK
|
idParam schemas.PK
|
||||||
OrderStr string
|
OrderStr string
|
||||||
|
orderArgs []interface{}
|
||||||
JoinStr string
|
JoinStr string
|
||||||
joinArgs []interface{}
|
joinArgs []interface{}
|
||||||
GroupByStr string
|
GroupByStr string
|
||||||
|
@ -129,7 +130,7 @@ func (statement *Statement) Reset() {
|
||||||
statement.RefTable = nil
|
statement.RefTable = nil
|
||||||
statement.Start = 0
|
statement.Start = 0
|
||||||
statement.LimitN = nil
|
statement.LimitN = nil
|
||||||
statement.OrderStr = ""
|
statement.ResetOrderBy()
|
||||||
statement.UseCascade = true
|
statement.UseCascade = true
|
||||||
statement.JoinStr = ""
|
statement.JoinStr = ""
|
||||||
statement.joinArgs = make([]interface{}, 0)
|
statement.joinArgs = make([]interface{}, 0)
|
||||||
|
@ -454,12 +455,32 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement {
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResetOrderBy reset ordery conditions
|
||||||
|
func (statement *Statement) ResetOrderBy() {
|
||||||
|
statement.OrderStr = ""
|
||||||
|
statement.orderArgs = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteOrderBy write order by to writer
|
||||||
|
func (statement *Statement) WriteOrderBy(w *builder.BytesWriter) error {
|
||||||
|
if len(statement.OrderStr) > 0 {
|
||||||
|
if _, err := fmt.Fprintf(w, " ORDER BY %s", statement.OrderStr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(statement.orderArgs...)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// OrderBy generate "Order By order" statement
|
// OrderBy generate "Order By order" statement
|
||||||
func (statement *Statement) OrderBy(order string) *Statement {
|
func (statement *Statement) OrderBy(order string, args ...interface{}) *Statement {
|
||||||
if len(statement.OrderStr) > 0 {
|
if len(statement.OrderStr) > 0 {
|
||||||
statement.OrderStr += ", "
|
statement.OrderStr += ", "
|
||||||
}
|
}
|
||||||
statement.OrderStr += statement.ReplaceQuote(order)
|
statement.OrderStr += statement.ReplaceQuote(order)
|
||||||
|
if len(args) > 0 {
|
||||||
|
statement.orderArgs = append(statement.orderArgs, args...)
|
||||||
|
}
|
||||||
return statement
|
return statement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -914,7 +935,8 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i
|
||||||
statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
|
statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) mergeConds(bean interface{}) error {
|
// MergeConds merge conditions from bean and id
|
||||||
|
func (statement *Statement) MergeConds(bean interface{}) error {
|
||||||
if !statement.NoAutoCondition && statement.RefTable != nil {
|
if !statement.NoAutoCondition && statement.RefTable != nil {
|
||||||
addedTableName := (len(statement.JoinStr) > 0)
|
addedTableName := (len(statement.JoinStr) > 0)
|
||||||
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
|
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
|
||||||
|
@ -927,15 +949,6 @@ func (statement *Statement) mergeConds(bean interface{}) error {
|
||||||
return statement.ProcessIDParam()
|
return statement.ProcessIDParam()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenConds generates conditions
|
|
||||||
func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) {
|
|
||||||
if err := statement.mergeConds(bean); err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return statement.GenCondSQL(statement.cond)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
func (statement *Statement) quoteColumnStr(columnStr string) string {
|
||||||
columns := strings.Split(columnStr, ",")
|
columns := strings.Split(columnStr, ",")
|
||||||
return statement.dialect.Quoter().Join(columns, ",")
|
return statement.dialect.Quoter().Join(columns, ",")
|
||||||
|
|
|
@ -275,8 +275,8 @@ func (session *Session) Limit(limit int, start ...int) *Session {
|
||||||
|
|
||||||
// OrderBy provide order by query condition, the input parameter is the content
|
// OrderBy provide order by query condition, the input parameter is the content
|
||||||
// after order by on a sql statement.
|
// after order by on a sql statement.
|
||||||
func (session *Session) OrderBy(order string) *Session {
|
func (session *Session) OrderBy(order string, args ...interface{}) *Session {
|
||||||
session.statement.OrderBy(order)
|
session.statement.OrderBy(order, args...)
|
||||||
return session
|
return session
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"xorm.io/builder"
|
||||||
"xorm.io/xorm/caches"
|
"xorm.io/xorm/caches"
|
||||||
"xorm.io/xorm/schemas"
|
"xorm.io/xorm/schemas"
|
||||||
)
|
)
|
||||||
|
@ -88,6 +89,16 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeBuilder(w *builder.BytesWriter, inputs ...*builder.BytesWriter) error {
|
||||||
|
for _, input := range inputs {
|
||||||
|
if _, err := fmt.Fprint(w, input.String()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(input.Args()...)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete records, bean's non-empty fields are conditions
|
// Delete records, bean's non-empty fields are conditions
|
||||||
func (session *Session) Delete(beans ...interface{}) (int64, error) {
|
func (session *Session) Delete(beans ...interface{}) (int64, error) {
|
||||||
if session.isAutoClose {
|
if session.isAutoClose {
|
||||||
|
@ -99,8 +110,7 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
condSQL string
|
condWriter = builder.NewWriter()
|
||||||
condArgs []interface{}
|
|
||||||
err error
|
err error
|
||||||
bean interface{}
|
bean interface{}
|
||||||
)
|
)
|
||||||
|
@ -116,115 +126,97 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
|
||||||
processor.BeforeDelete()
|
processor.BeforeDelete()
|
||||||
}
|
}
|
||||||
|
|
||||||
condSQL, condArgs, err = session.statement.GenConds(bean)
|
if err = session.statement.MergeConds(bean); err != nil {
|
||||||
} else {
|
return 0, err
|
||||||
condSQL, condArgs, err = session.statement.GenCondSQL(session.statement.Conds())
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
}
|
||||||
|
|
||||||
|
if err = session.statement.Conds().WriteTo(condWriter); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pLimitN := session.statement.LimitN
|
pLimitN := session.statement.LimitN
|
||||||
if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) {
|
if condWriter.Len() == 0 && (pLimitN == nil || *pLimitN == 0) {
|
||||||
return 0, ErrNeedDeletedCond
|
return 0, ErrNeedDeletedCond
|
||||||
}
|
}
|
||||||
|
|
||||||
var tableNameNoQuote = session.statement.TableName()
|
tableNameNoQuote := session.statement.TableName()
|
||||||
var tableName = session.engine.Quote(tableNameNoQuote)
|
tableName := session.engine.Quote(tableNameNoQuote)
|
||||||
var table = session.statement.RefTable
|
table := session.statement.RefTable
|
||||||
var deleteSQL string
|
deleteSQLWriter := builder.NewWriter()
|
||||||
if len(condSQL) > 0 {
|
fmt.Fprintf(deleteSQLWriter, "DELETE FROM %v", tableName)
|
||||||
deleteSQL = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condSQL)
|
if condWriter.Len() > 0 {
|
||||||
} else {
|
fmt.Fprintf(deleteSQLWriter, " WHERE %v", condWriter.String())
|
||||||
deleteSQL = fmt.Sprintf("DELETE FROM %v", tableName)
|
deleteSQLWriter.Append(condWriter.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
var orderSQL string
|
orderSQLWriter := builder.NewWriter()
|
||||||
if len(session.statement.OrderStr) > 0 {
|
if err := session.statement.WriteOrderBy(orderSQLWriter); err != nil {
|
||||||
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if pLimitN != nil && *pLimitN > 0 {
|
if pLimitN != nil && *pLimitN > 0 {
|
||||||
limitNValue := *pLimitN
|
limitNValue := *pLimitN
|
||||||
orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue)
|
if _, err := fmt.Fprintf(orderSQLWriter, " LIMIT %d", limitNValue); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(orderSQL) > 0 {
|
orderCondWriter := builder.NewWriter()
|
||||||
|
if orderSQLWriter.Len() > 0 {
|
||||||
switch session.engine.dialect.URI().DBType {
|
switch session.engine.dialect.URI().DBType {
|
||||||
case schemas.POSTGRES:
|
case schemas.POSTGRES:
|
||||||
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
|
if condWriter.Len() > 0 {
|
||||||
if len(condSQL) > 0 {
|
fmt.Fprintf(orderCondWriter, " AND ")
|
||||||
deleteSQL += " AND " + inSQL
|
|
||||||
} else {
|
} else {
|
||||||
deleteSQL += " WHERE " + inSQL
|
fmt.Fprintf(orderCondWriter, " WHERE ")
|
||||||
}
|
}
|
||||||
|
fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String())
|
||||||
|
orderCondWriter.Append(orderSQLWriter.Args()...)
|
||||||
case schemas.SQLITE:
|
case schemas.SQLITE:
|
||||||
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
|
if condWriter.Len() > 0 {
|
||||||
if len(condSQL) > 0 {
|
fmt.Fprintf(orderCondWriter, " AND ")
|
||||||
deleteSQL += " AND " + inSQL
|
|
||||||
} else {
|
} else {
|
||||||
deleteSQL += " WHERE " + inSQL
|
fmt.Fprintf(orderCondWriter, " WHERE ")
|
||||||
}
|
}
|
||||||
|
fmt.Fprintf(orderCondWriter, "rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQLWriter.String())
|
||||||
// TODO: how to handle delete limit on mssql?
|
// TODO: how to handle delete limit on mssql?
|
||||||
case schemas.MSSQL:
|
case schemas.MSSQL:
|
||||||
return 0, ErrNotImplemented
|
return 0, ErrNotImplemented
|
||||||
default:
|
default:
|
||||||
deleteSQL += orderSQL
|
fmt.Fprint(orderCondWriter, orderSQLWriter.String())
|
||||||
|
orderCondWriter.Append(orderSQLWriter.Args()...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var realSQL string
|
realSQLWriter := builder.NewWriter()
|
||||||
argsForCache := make([]interface{}, 0, len(condArgs)*2)
|
argsForCache := make([]interface{}, 0, len(deleteSQLWriter.Args())*2)
|
||||||
|
copy(argsForCache, deleteSQLWriter.Args())
|
||||||
|
argsForCache = append(deleteSQLWriter.Args(), argsForCache...)
|
||||||
if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled
|
if session.statement.GetUnscoped() || table == nil || table.DeletedColumn() == nil { // tag "deleted" is disabled
|
||||||
realSQL = deleteSQL
|
if err := writeBuilder(realSQLWriter, deleteSQLWriter, orderCondWriter); err != nil {
|
||||||
copy(argsForCache, condArgs)
|
return 0, err
|
||||||
argsForCache = append(condArgs, argsForCache...)
|
}
|
||||||
} else {
|
} else {
|
||||||
// !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for caches.
|
|
||||||
copy(argsForCache, condArgs)
|
|
||||||
argsForCache = append(condArgs, argsForCache...)
|
|
||||||
|
|
||||||
deletedColumn := table.DeletedColumn()
|
deletedColumn := table.DeletedColumn()
|
||||||
realSQL = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v",
|
if _, err := fmt.Fprintf(realSQLWriter, "UPDATE %v SET %v = ? WHERE %v",
|
||||||
session.engine.Quote(session.statement.TableName()),
|
session.engine.Quote(session.statement.TableName()),
|
||||||
session.engine.Quote(deletedColumn.Name),
|
session.engine.Quote(deletedColumn.Name),
|
||||||
condSQL)
|
condWriter.String()); err != nil {
|
||||||
|
return 0, err
|
||||||
if len(orderSQL) > 0 {
|
|
||||||
switch session.engine.dialect.URI().DBType {
|
|
||||||
case schemas.POSTGRES:
|
|
||||||
inSQL := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQL)
|
|
||||||
if len(condSQL) > 0 {
|
|
||||||
realSQL += " AND " + inSQL
|
|
||||||
} else {
|
|
||||||
realSQL += " WHERE " + inSQL
|
|
||||||
}
|
}
|
||||||
case schemas.SQLITE:
|
|
||||||
inSQL := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSQL)
|
|
||||||
if len(condSQL) > 0 {
|
|
||||||
realSQL += " AND " + inSQL
|
|
||||||
} else {
|
|
||||||
realSQL += " WHERE " + inSQL
|
|
||||||
}
|
|
||||||
// TODO: how to handle delete limit on mssql?
|
|
||||||
case schemas.MSSQL:
|
|
||||||
return 0, ErrNotImplemented
|
|
||||||
default:
|
|
||||||
realSQL += orderSQL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// !oinume! Insert nowTime to the head of session.statement.Params
|
|
||||||
condArgs = append(condArgs, "")
|
|
||||||
paramsLen := len(condArgs)
|
|
||||||
copy(condArgs[1:paramsLen], condArgs[0:paramsLen-1])
|
|
||||||
|
|
||||||
val, t, err := session.engine.nowTime(deletedColumn)
|
val, t, err := session.engine.nowTime(deletedColumn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
condArgs[0] = val
|
realSQLWriter.Append(val)
|
||||||
|
realSQLWriter.Append(condWriter.Args()...)
|
||||||
|
|
||||||
var colName = deletedColumn.Name
|
if err := writeBuilder(realSQLWriter, orderCondWriter); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
colName := deletedColumn.Name
|
||||||
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
|
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
|
||||||
col := table.GetColumn(colName)
|
col := table.GetColumn(colName)
|
||||||
setColumnTime(bean, col, t)
|
setColumnTime(bean, col, t)
|
||||||
|
@ -232,11 +224,11 @@ func (session *Session) Delete(beans ...interface{}) (int64, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache {
|
if cacher := session.engine.GetCacher(tableNameNoQuote); cacher != nil && session.statement.UseCache {
|
||||||
_ = session.cacheDelete(table, tableNameNoQuote, deleteSQL, argsForCache...)
|
_ = session.cacheDelete(table, tableNameNoQuote, deleteSQLWriter.String(), argsForCache...)
|
||||||
}
|
}
|
||||||
|
|
||||||
session.statement.RefTable = table
|
session.statement.RefTable = table
|
||||||
res, err := session.exec(realSQL, condArgs...)
|
res, err := session.exec(realSQLWriter.String(), realSQLWriter.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,9 +60,7 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte
|
||||||
if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct {
|
if len(session.statement.ColumnMap) > 0 && !session.statement.IsDistinct {
|
||||||
session.statement.ColumnMap = []string{}
|
session.statement.ColumnMap = []string{}
|
||||||
}
|
}
|
||||||
if session.statement.OrderStr != "" {
|
session.statement.ResetOrderBy()
|
||||||
session.statement.OrderStr = ""
|
|
||||||
}
|
|
||||||
if session.statement.LimitN != nil {
|
if session.statement.LimitN != nil {
|
||||||
session.statement.LimitN = nil
|
session.statement.LimitN = nil
|
||||||
}
|
}
|
||||||
|
@ -85,15 +83,15 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
|
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
|
||||||
var isSlice = sliceValue.Kind() == reflect.Slice
|
isSlice := sliceValue.Kind() == reflect.Slice
|
||||||
var isMap = sliceValue.Kind() == reflect.Map
|
isMap := sliceValue.Kind() == reflect.Map
|
||||||
if !isSlice && !isMap {
|
if !isSlice && !isMap {
|
||||||
return errors.New("needs a pointer to a slice or a map")
|
return errors.New("needs a pointer to a slice or a map")
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceElementType := sliceValue.Type().Elem()
|
sliceElementType := sliceValue.Type().Elem()
|
||||||
|
|
||||||
var tp = tpStruct
|
tp := tpStruct
|
||||||
if session.statement.RefTable == nil {
|
if session.statement.RefTable == nil {
|
||||||
if sliceElementType.Kind() == reflect.Ptr {
|
if sliceElementType.Kind() == reflect.Ptr {
|
||||||
if sliceElementType.Elem().Kind() == reflect.Struct {
|
if sliceElementType.Elem().Kind() == reflect.Struct {
|
||||||
|
@ -190,7 +188,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var newElemFunc = func(fields []string) reflect.Value {
|
newElemFunc := func(fields []string) reflect.Value {
|
||||||
return utils.New(elemType, len(fields), len(fields))
|
return utils.New(elemType, len(fields), len(fields))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,7 +233,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
||||||
}
|
}
|
||||||
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
if elemType.Kind() == reflect.Struct {
|
||||||
var newValue = newElemFunc(fields)
|
newValue := newElemFunc(fields)
|
||||||
tb, err := session.engine.tagParser.ParseWithCache(newValue)
|
tb, err := session.engine.tagParser.ParseWithCache(newValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -249,7 +247,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect
|
||||||
}
|
}
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var newValue = newElemFunc(fields)
|
newValue := newElemFunc(fields)
|
||||||
bean := newValue.Interface()
|
bean := newValue.Interface()
|
||||||
|
|
||||||
switch elemType.Kind() {
|
switch elemType.Kind() {
|
||||||
|
@ -310,7 +308,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
|
||||||
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")
|
session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache")
|
||||||
return ErrCacheFailed
|
return ErrCacheFailed
|
||||||
}
|
}
|
||||||
var res = make([]string, len(table.PrimaryKeys))
|
res := make([]string, len(table.PrimaryKeys))
|
||||||
err = rows.ScanSlice(&res)
|
err = rows.ScanSlice(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -342,7 +340,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
|
||||||
|
|
||||||
ididxes := make(map[string]int)
|
ididxes := make(map[string]int)
|
||||||
var ides []schemas.PK
|
var ides []schemas.PK
|
||||||
var temps = make([]interface{}, len(ids))
|
temps := make([]interface{}, len(ids))
|
||||||
|
|
||||||
for idx, id := range ids {
|
for idx, id := range ids {
|
||||||
sid, err := id.ToString()
|
sid, err := id.ToString()
|
||||||
|
@ -457,7 +455,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
|
||||||
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
|
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
|
||||||
}
|
}
|
||||||
} else if sliceValue.Kind() == reflect.Map {
|
} else if sliceValue.Kind() == reflect.Map {
|
||||||
var key = ids[j]
|
key := ids[j]
|
||||||
keyType := sliceValue.Type().Key()
|
keyType := sliceValue.Type().Key()
|
||||||
keyValue := reflect.New(keyType)
|
keyValue := reflect.New(keyType)
|
||||||
var ikey interface{}
|
var ikey interface{}
|
||||||
|
|
Loading…
Reference in New Issue