fix bug and catch more tests (#613)
This commit is contained in:
parent
7a9bf19c65
commit
1e055bac01
9
rows.go
9
rows.go
|
@ -33,8 +33,9 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
if err := rows.session.Statement.setRefValue(rValue(bean)); err != nil {
|
if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +44,10 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.session.Statement.RawSQL == "" {
|
if rows.session.Statement.RawSQL == "" {
|
||||||
sqlStr, args = rows.session.Statement.genGetSQL(bean)
|
sqlStr, args, err = rows.session.Statement.genGetSQL(bean)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = rows.session.Statement.RawSQL
|
sqlStr = rows.session.Statement.RawSQL
|
||||||
args = rows.session.Statement.RawParams
|
args = rows.session.Statement.RawParams
|
||||||
|
@ -54,7 +58,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
rows.session.saveLastSQL(sqlStr, args...)
|
rows.session.saveLastSQL(sqlStr, args...)
|
||||||
var err error
|
|
||||||
if rows.session.prepareStmt {
|
if rows.session.prepareStmt {
|
||||||
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
|
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -48,6 +48,8 @@ type Session struct {
|
||||||
//beforeSQLExec func(string, ...interface{})
|
//beforeSQLExec func(string, ...interface{})
|
||||||
lastSQL string
|
lastSQL string
|
||||||
lastSQLArgs []interface{}
|
lastSQLArgs []interface{}
|
||||||
|
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone copy all the session's content and return a new session
|
// Clone copy all the session's content and return a new session
|
||||||
|
@ -620,7 +622,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
|
||||||
structInter := reflect.New(fieldValue.Type())
|
structInter := reflect.New(fieldValue.Type())
|
||||||
newsession := session.Engine.NewSession()
|
newsession := session.Engine.NewSession()
|
||||||
defer newsession.Close()
|
defer newsession.Close()
|
||||||
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
|
has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,8 +98,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
||||||
processor.BeforeDelete()
|
processor.BeforeDelete()
|
||||||
}
|
}
|
||||||
|
|
||||||
// --
|
condSQL, condArgs, err := session.Statement.genConds(bean)
|
||||||
condSQL, condArgs, _ := session.Statement.genConds(bean)
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if len(condSQL) == 0 && session.Statement.LimitN == 0 {
|
if len(condSQL) == 0 && session.Statement.LimitN == 0 {
|
||||||
return 0, ErrNeedDeletedCond
|
return 0, ErrNeedDeletedCond
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ func TestDelete(t *testing.T) {
|
||||||
assert.NoError(t, prepareEngine())
|
assert.NoError(t, prepareEngine())
|
||||||
|
|
||||||
type UserinfoDelete struct {
|
type UserinfoDelete struct {
|
||||||
Uid int64
|
Uid int64 `xorm:"id pk not null autoincr"`
|
||||||
IsMan bool
|
IsMan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -91,6 +91,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
if session.Statement.RawSQL == "" {
|
if session.Statement.RawSQL == "" {
|
||||||
if len(session.Statement.TableName()) <= 0 {
|
if len(session.Statement.TableName()) <= 0 {
|
||||||
return ErrTableNotFound
|
return ErrTableNotFound
|
||||||
|
@ -128,7 +129,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
}
|
}
|
||||||
|
|
||||||
args = append(session.Statement.joinArgs, condArgs...)
|
args = append(session.Statement.joinArgs, condArgs...)
|
||||||
sqlStr = session.Statement.genSelectSQL(columnStr, condSQL)
|
sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// for mssql and use limit
|
// for mssql and use limit
|
||||||
qs := strings.Count(sqlStr, "?")
|
qs := strings.Count(sqlStr, "?")
|
||||||
if len(args)*2 == qs {
|
if len(args)*2 == qs {
|
||||||
|
@ -139,7 +143,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
||||||
args = session.Statement.RawParams
|
args = session.Statement.RawParams
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
if session.canCache() {
|
if session.canCache() {
|
||||||
if cacher := session.Engine.getCacher2(table); cacher != nil &&
|
if cacher := session.Engine.getCacher2(table); cacher != nil &&
|
||||||
!session.Statement.IsDistinct &&
|
!session.Statement.IsDistinct &&
|
||||||
|
|
|
@ -33,13 +33,17 @@ func (session *Session) Get(bean interface{}) (bool, error) {
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
if session.Statement.RawSQL == "" {
|
if session.Statement.RawSQL == "" {
|
||||||
if len(session.Statement.TableName()) <= 0 {
|
if len(session.Statement.TableName()) <= 0 {
|
||||||
return false, ErrTableNotFound
|
return false, ErrTableNotFound
|
||||||
}
|
}
|
||||||
session.Statement.Limit(1)
|
session.Statement.Limit(1)
|
||||||
sqlStr, args = session.Statement.genGetSQL(bean)
|
sqlStr, args, err = session.Statement.genGetSQL(bean)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.Statement.RawSQL
|
sqlStr = session.Statement.RawSQL
|
||||||
args = session.Statement.RawParams
|
args = session.Statement.RawParams
|
||||||
|
|
|
@ -1139,3 +1139,29 @@ func TestCompositePK(t *testing.T) {
|
||||||
assert.EqualValues(t, "uid", pkCols[0].Name)
|
assert.EqualValues(t, "uid", pkCols[0].Name)
|
||||||
assert.EqualValues(t, "tid", pkCols[1].Name)
|
assert.EqualValues(t, "tid", pkCols[1].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNoPKIdQueryUpdate(t *testing.T) {
|
||||||
|
type NoPKTable struct {
|
||||||
|
Username string
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
assertSync(t, new(NoPKTable))
|
||||||
|
|
||||||
|
cnt, err := testEngine.Insert(&NoPKTable{
|
||||||
|
Username: "test",
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, cnt)
|
||||||
|
|
||||||
|
var res NoPKTable
|
||||||
|
has, err := testEngine.ID("test").Get(&res)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.False(t, has)
|
||||||
|
|
||||||
|
cnt, err = testEngine.ID("test").Update(&NoPKTable{
|
||||||
|
Username: "test1",
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.EqualValues(t, 0, cnt)
|
||||||
|
}
|
||||||
|
|
|
@ -16,11 +16,15 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
if session.Statement.RawSQL == "" {
|
if session.Statement.RawSQL == "" {
|
||||||
if len(bean) == 0 {
|
if len(bean) == 0 {
|
||||||
return 0, ErrTableNotFound
|
return 0, ErrTableNotFound
|
||||||
}
|
}
|
||||||
sqlStr, args = session.Statement.genCountSQL(bean[0])
|
sqlStr, args, err = session.Statement.genCountSQL(bean[0])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.Statement.RawSQL
|
sqlStr = session.Statement.RawSQL
|
||||||
args = session.Statement.RawParams
|
args = session.Statement.RawParams
|
||||||
|
@ -28,7 +32,6 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
|
||||||
|
|
||||||
session.queryPreprocess(&sqlStr, args...)
|
session.queryPreprocess(&sqlStr, args...)
|
||||||
|
|
||||||
var err error
|
|
||||||
var total int64
|
var total int64
|
||||||
if session.IsAutoCommit {
|
if session.IsAutoCommit {
|
||||||
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
|
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
|
||||||
|
@ -52,8 +55,12 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
if len(session.Statement.RawSQL) == 0 {
|
if len(session.Statement.RawSQL) == 0 {
|
||||||
sqlStr, args = session.Statement.genSumSQL(bean, columnName)
|
sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.Statement.RawSQL
|
sqlStr = session.Statement.RawSQL
|
||||||
args = session.Statement.RawParams
|
args = session.Statement.RawParams
|
||||||
|
@ -61,7 +68,6 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
|
||||||
|
|
||||||
session.queryPreprocess(&sqlStr, args...)
|
session.queryPreprocess(&sqlStr, args...)
|
||||||
|
|
||||||
var err error
|
|
||||||
var res float64
|
var res float64
|
||||||
if session.IsAutoCommit {
|
if session.IsAutoCommit {
|
||||||
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
|
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
|
||||||
|
@ -84,8 +90,12 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
if len(session.Statement.RawSQL) == 0 {
|
if len(session.Statement.RawSQL) == 0 {
|
||||||
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
|
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.Statement.RawSQL
|
sqlStr = session.Statement.RawSQL
|
||||||
args = session.Statement.RawParams
|
args = session.Statement.RawParams
|
||||||
|
@ -93,7 +103,6 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
|
||||||
|
|
||||||
session.queryPreprocess(&sqlStr, args...)
|
session.queryPreprocess(&sqlStr, args...)
|
||||||
|
|
||||||
var err error
|
|
||||||
var res = make([]float64, len(columnNames), len(columnNames))
|
var res = make([]float64, len(columnNames), len(columnNames))
|
||||||
if session.IsAutoCommit {
|
if session.IsAutoCommit {
|
||||||
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
|
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
|
||||||
|
@ -116,8 +125,12 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
|
||||||
|
|
||||||
var sqlStr string
|
var sqlStr string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
|
var err error
|
||||||
if len(session.Statement.RawSQL) == 0 {
|
if len(session.Statement.RawSQL) == 0 {
|
||||||
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
|
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
sqlStr = session.Statement.RawSQL
|
sqlStr = session.Statement.RawSQL
|
||||||
args = session.Statement.RawParams
|
args = session.Statement.RawParams
|
||||||
|
@ -125,7 +138,6 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
|
||||||
|
|
||||||
session.queryPreprocess(&sqlStr, args...)
|
session.queryPreprocess(&sqlStr, args...)
|
||||||
|
|
||||||
var err error
|
|
||||||
var res = make([]int64, len(columnNames), len(columnNames))
|
var res = make([]int64, len(columnNames), len(columnNames))
|
||||||
if session.IsAutoCommit {
|
if session.IsAutoCommit {
|
||||||
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
|
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
|
||||||
|
|
|
@ -236,7 +236,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
|
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Statement.processIDParam()
|
if err = session.Statement.processIDParam(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
var autoCond builder.Cond
|
var autoCond builder.Cond
|
||||||
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
|
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
|
||||||
|
@ -267,7 +269,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
|
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
|
@ -285,7 +291,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||||
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
|
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
|
||||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
|
@ -293,7 +302,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||||
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
|
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
|
||||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
|
@ -304,7 +317,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
|
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
|
||||||
session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...)
|
session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...)
|
||||||
|
|
||||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
|
|
71
statement.go
71
statement.go
|
@ -1118,12 +1118,14 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
|
||||||
statement.cond = statement.cond.And(autoCond)
|
statement.cond = statement.cond.And(autoCond)
|
||||||
}
|
}
|
||||||
|
|
||||||
statement.processIDParam()
|
if err := statement.processIDParam(); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return builder.ToSQL(statement.cond)
|
return builder.ToSQL(statement.cond)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
|
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
|
||||||
v := rValue(bean)
|
v := rValue(bean)
|
||||||
isStruct := v.Kind() == reflect.Struct
|
isStruct := v.Kind() == reflect.Struct
|
||||||
if isStruct {
|
if isStruct {
|
||||||
|
@ -1158,19 +1160,31 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
|
||||||
|
|
||||||
var condSQL string
|
var condSQL string
|
||||||
var condArgs []interface{}
|
var condArgs []interface{}
|
||||||
|
var err error
|
||||||
if isStruct {
|
if isStruct {
|
||||||
condSQL, condArgs, _ = statement.genConds(bean)
|
condSQL, condArgs, err = statement.genConds(bean)
|
||||||
} else {
|
} else {
|
||||||
condSQL, condArgs, _ = builder.ToSQL(statement.cond)
|
condSQL, condArgs, err = builder.ToSQL(statement.cond)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
|
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
|
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) {
|
||||||
statement.setRefValue(rValue(bean))
|
statement.setRefValue(rValue(bean))
|
||||||
|
|
||||||
condSQL, condArgs, _ := statement.genConds(bean)
|
condSQL, condArgs, err := statement.genConds(bean)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
var selectSQL = statement.selectStr
|
var selectSQL = statement.selectStr
|
||||||
if len(selectSQL) <= 0 {
|
if len(selectSQL) <= 0 {
|
||||||
|
@ -1180,10 +1194,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}
|
||||||
selectSQL = "count(*)"
|
selectSQL = "count(*)"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
|
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
|
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
|
||||||
statement.setRefValue(rValue(bean))
|
statement.setRefValue(rValue(bean))
|
||||||
|
|
||||||
var sumStrs = make([]string, 0, len(columns))
|
var sumStrs = make([]string, 0, len(columns))
|
||||||
|
@ -1195,12 +1214,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
||||||
}
|
}
|
||||||
sumSelect := strings.Join(sumStrs, ", ")
|
sumSelect := strings.Join(sumStrs, ", ")
|
||||||
|
|
||||||
condSQL, condArgs, _ := statement.genConds(bean)
|
condSQL, condArgs, err := statement.genConds(bean)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...)
|
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
|
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
|
||||||
var distinct string
|
var distinct string
|
||||||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
||||||
distinct = "DISTINCT "
|
distinct = "DISTINCT "
|
||||||
|
@ -1211,7 +1238,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
|
||||||
var top string
|
var top string
|
||||||
var mssqlCondi string
|
var mssqlCondi string
|
||||||
|
|
||||||
statement.processIDParam()
|
if err := statement.processIDParam(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
if len(condSQL) > 0 {
|
if len(condSQL) > 0 {
|
||||||
|
@ -1314,19 +1343,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) processIDParam() {
|
func (statement *Statement) processIDParam() error {
|
||||||
if statement.idParam == nil {
|
if statement.idParam == nil {
|
||||||
return
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
|
||||||
|
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
|
||||||
|
len(statement.RefTable.PrimaryKeys),
|
||||||
|
len(*statement.idParam),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, col := range statement.RefTable.PKColumns() {
|
for i, col := range statement.RefTable.PKColumns() {
|
||||||
var colName = statement.colName(col, statement.TableName())
|
var colName = statement.colName(col, statement.TableName())
|
||||||
if i < len(*(statement.idParam)) {
|
|
||||||
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
|
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
|
||||||
} else {
|
|
||||||
statement.cond = statement.cond.And(builder.Eq{colName: ""})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
|
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
|
||||||
|
|
Loading…
Reference in New Issue