fix bug and catch more tests (#613)
This commit is contained in:
parent
7a9bf19c65
commit
1e055bac01
9
rows.go
9
rows.go
|
@ -33,8 +33,9 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
|
||||
if err := rows.session.Statement.setRefValue(rValue(bean)); err != nil {
|
||||
if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -43,7 +44,10 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
|||
}
|
||||
|
||||
if rows.session.Statement.RawSQL == "" {
|
||||
sqlStr, args = rows.session.Statement.genGetSQL(bean)
|
||||
sqlStr, args, err = rows.session.Statement.genGetSQL(bean)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
sqlStr = rows.session.Statement.RawSQL
|
||||
args = rows.session.Statement.RawParams
|
||||
|
@ -54,7 +58,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
|
|||
}
|
||||
|
||||
rows.session.saveLastSQL(sqlStr, args...)
|
||||
var err error
|
||||
if rows.session.prepareStmt {
|
||||
rows.stmt, err = rows.session.DB().Prepare(sqlStr)
|
||||
if err != nil {
|
||||
|
|
|
@ -48,6 +48,8 @@ type Session struct {
|
|||
//beforeSQLExec func(string, ...interface{})
|
||||
lastSQL string
|
||||
lastSQLArgs []interface{}
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
// Clone copy all the session's content and return a new session
|
||||
|
@ -620,7 +622,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
|
|||
structInter := reflect.New(fieldValue.Type())
|
||||
newsession := session.Engine.NewSession()
|
||||
defer newsession.Close()
|
||||
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface())
|
||||
has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -98,8 +98,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
|||
processor.BeforeDelete()
|
||||
}
|
||||
|
||||
// --
|
||||
condSQL, condArgs, _ := session.Statement.genConds(bean)
|
||||
condSQL, condArgs, err := session.Statement.genConds(bean)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(condSQL) == 0 && session.Statement.LimitN == 0 {
|
||||
return 0, ErrNeedDeletedCond
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ func TestDelete(t *testing.T) {
|
|||
assert.NoError(t, prepareEngine())
|
||||
|
||||
type UserinfoDelete struct {
|
||||
Uid int64
|
||||
Uid int64 `xorm:"id pk not null autoincr"`
|
||||
IsMan bool
|
||||
}
|
||||
|
||||
|
|
|
@ -91,6 +91,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
if session.Statement.RawSQL == "" {
|
||||
if len(session.Statement.TableName()) <= 0 {
|
||||
return ErrTableNotFound
|
||||
|
@ -128,7 +129,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
}
|
||||
|
||||
args = append(session.Statement.joinArgs, condArgs...)
|
||||
sqlStr = session.Statement.genSelectSQL(columnStr, condSQL)
|
||||
sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// for mssql and use limit
|
||||
qs := strings.Count(sqlStr, "?")
|
||||
if len(args)*2 == qs {
|
||||
|
@ -139,7 +143,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
|
|||
args = session.Statement.RawParams
|
||||
}
|
||||
|
||||
var err error
|
||||
if session.canCache() {
|
||||
if cacher := session.Engine.getCacher2(table); cacher != nil &&
|
||||
!session.Statement.IsDistinct &&
|
||||
|
|
|
@ -33,13 +33,17 @@ func (session *Session) Get(bean interface{}) (bool, error) {
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
|
||||
if session.Statement.RawSQL == "" {
|
||||
if len(session.Statement.TableName()) <= 0 {
|
||||
return false, ErrTableNotFound
|
||||
}
|
||||
session.Statement.Limit(1)
|
||||
sqlStr, args = session.Statement.genGetSQL(bean)
|
||||
sqlStr, args, err = session.Statement.genGetSQL(bean)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
sqlStr = session.Statement.RawSQL
|
||||
args = session.Statement.RawParams
|
||||
|
|
|
@ -1139,3 +1139,29 @@ func TestCompositePK(t *testing.T) {
|
|||
assert.EqualValues(t, "uid", pkCols[0].Name)
|
||||
assert.EqualValues(t, "tid", pkCols[1].Name)
|
||||
}
|
||||
|
||||
func TestNoPKIdQueryUpdate(t *testing.T) {
|
||||
type NoPKTable struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
assert.NoError(t, prepareEngine())
|
||||
assertSync(t, new(NoPKTable))
|
||||
|
||||
cnt, err := testEngine.Insert(&NoPKTable{
|
||||
Username: "test",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
||||
var res NoPKTable
|
||||
has, err := testEngine.ID("test").Get(&res)
|
||||
assert.Error(t, err)
|
||||
assert.False(t, has)
|
||||
|
||||
cnt, err = testEngine.ID("test").Update(&NoPKTable{
|
||||
Username: "test1",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.EqualValues(t, 0, cnt)
|
||||
}
|
||||
|
|
|
@ -16,11 +16,15 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
if session.Statement.RawSQL == "" {
|
||||
if len(bean) == 0 {
|
||||
return 0, ErrTableNotFound
|
||||
}
|
||||
sqlStr, args = session.Statement.genCountSQL(bean[0])
|
||||
sqlStr, args, err = session.Statement.genCountSQL(bean[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
sqlStr = session.Statement.RawSQL
|
||||
args = session.Statement.RawParams
|
||||
|
@ -28,7 +32,6 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
|
|||
|
||||
session.queryPreprocess(&sqlStr, args...)
|
||||
|
||||
var err error
|
||||
var total int64
|
||||
if session.IsAutoCommit {
|
||||
err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
|
||||
|
@ -52,8 +55,12 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
if len(session.Statement.RawSQL) == 0 {
|
||||
sqlStr, args = session.Statement.genSumSQL(bean, columnName)
|
||||
sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
sqlStr = session.Statement.RawSQL
|
||||
args = session.Statement.RawParams
|
||||
|
@ -61,7 +68,6 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
|
|||
|
||||
session.queryPreprocess(&sqlStr, args...)
|
||||
|
||||
var err error
|
||||
var res float64
|
||||
if session.IsAutoCommit {
|
||||
err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
|
||||
|
@ -84,8 +90,12 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
if len(session.Statement.RawSQL) == 0 {
|
||||
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
|
||||
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
sqlStr = session.Statement.RawSQL
|
||||
args = session.Statement.RawParams
|
||||
|
@ -93,7 +103,6 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
|
|||
|
||||
session.queryPreprocess(&sqlStr, args...)
|
||||
|
||||
var err error
|
||||
var res = make([]float64, len(columnNames), len(columnNames))
|
||||
if session.IsAutoCommit {
|
||||
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
|
||||
|
@ -116,8 +125,12 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
|
|||
|
||||
var sqlStr string
|
||||
var args []interface{}
|
||||
var err error
|
||||
if len(session.Statement.RawSQL) == 0 {
|
||||
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
|
||||
sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
sqlStr = session.Statement.RawSQL
|
||||
args = session.Statement.RawParams
|
||||
|
@ -125,7 +138,6 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
|
|||
|
||||
session.queryPreprocess(&sqlStr, args...)
|
||||
|
||||
var err error
|
||||
var res = make([]int64, len(columnNames), len(columnNames))
|
||||
if session.IsAutoCommit {
|
||||
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
|
||||
|
|
|
@ -236,7 +236,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
|
||||
}
|
||||
|
||||
session.Statement.processIDParam()
|
||||
if err = session.Statement.processIDParam(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var autoCond builder.Cond
|
||||
if !session.Statement.noAutoCondition && len(condiBean) > 0 {
|
||||
|
@ -267,7 +269,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
|
||||
}
|
||||
|
||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(condSQL) > 0 {
|
||||
condSQL = "WHERE " + condSQL
|
||||
}
|
||||
|
@ -285,7 +291,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
|
||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(condSQL) > 0 {
|
||||
condSQL = "WHERE " + condSQL
|
||||
}
|
||||
|
@ -293,7 +302,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
|
||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(condSQL) > 0 {
|
||||
condSQL = "WHERE " + condSQL
|
||||
}
|
||||
|
@ -304,7 +317,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
|
||||
session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...)
|
||||
|
||||
condSQL, condArgs, _ = builder.ToSQL(cond)
|
||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(condSQL) > 0 {
|
||||
condSQL = "WHERE " + condSQL
|
||||
}
|
||||
|
|
73
statement.go
73
statement.go
|
@ -1118,12 +1118,14 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
|
|||
statement.cond = statement.cond.And(autoCond)
|
||||
}
|
||||
|
||||
statement.processIDParam()
|
||||
if err := statement.processIDParam(); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return builder.ToSQL(statement.cond)
|
||||
}
|
||||
|
||||
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
|
||||
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
|
||||
v := rValue(bean)
|
||||
isStruct := v.Kind() == reflect.Struct
|
||||
if isStruct {
|
||||
|
@ -1158,19 +1160,31 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
|
|||
|
||||
var condSQL string
|
||||
var condArgs []interface{}
|
||||
var err error
|
||||
if isStruct {
|
||||
condSQL, condArgs, _ = statement.genConds(bean)
|
||||
condSQL, condArgs, err = statement.genConds(bean)
|
||||
} else {
|
||||
condSQL, condArgs, _ = builder.ToSQL(statement.cond)
|
||||
condSQL, condArgs, err = builder.ToSQL(statement.cond)
|
||||
}
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
|
||||
sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
|
||||
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||
}
|
||||
|
||||
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) {
|
||||
statement.setRefValue(rValue(bean))
|
||||
|
||||
condSQL, condArgs, _ := statement.genConds(bean)
|
||||
condSQL, condArgs, err := statement.genConds(bean)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
var selectSQL = statement.selectStr
|
||||
if len(selectSQL) <= 0 {
|
||||
|
@ -1180,10 +1194,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}
|
|||
selectSQL = "count(*)"
|
||||
}
|
||||
}
|
||||
return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
|
||||
sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
|
||||
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||
}
|
||||
|
||||
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
|
||||
statement.setRefValue(rValue(bean))
|
||||
|
||||
var sumStrs = make([]string, 0, len(columns))
|
||||
|
@ -1195,12 +1214,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
|
|||
}
|
||||
sumSelect := strings.Join(sumStrs, ", ")
|
||||
|
||||
condSQL, condArgs, _ := statement.genConds(bean)
|
||||
|
||||
return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...)
|
||||
condSQL, condArgs, err := statement.genConds(bean)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
|
||||
sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return sqlStr, append(statement.joinArgs, condArgs...), nil
|
||||
}
|
||||
|
||||
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
|
||||
var distinct string
|
||||
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
|
||||
distinct = "DISTINCT "
|
||||
|
@ -1211,7 +1238,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
|
|||
var top string
|
||||
var mssqlCondi string
|
||||
|
||||
statement.processIDParam()
|
||||
if err := statement.processIDParam(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if len(condSQL) > 0 {
|
||||
|
@ -1314,19 +1343,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (statement *Statement) processIDParam() {
|
||||
func (statement *Statement) processIDParam() error {
|
||||
if statement.idParam == nil {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
|
||||
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
|
||||
len(statement.RefTable.PrimaryKeys),
|
||||
len(*statement.idParam),
|
||||
)
|
||||
}
|
||||
|
||||
for i, col := range statement.RefTable.PKColumns() {
|
||||
var colName = statement.colName(col, statement.TableName())
|
||||
if i < len(*(statement.idParam)) {
|
||||
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
|
||||
} else {
|
||||
statement.cond = statement.cond.And(builder.Eq{colName: ""})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
|
||||
|
|
Loading…
Reference in New Issue