From 1e055bac01000c2c365beb5ba6fb778a65b6376e Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 8 Jun 2017 19:38:52 +0800 Subject: [PATCH] fix bug and catch more tests (#613) --- rows.go | 9 ++++-- session.go | 4 ++- session_delete.go | 6 ++-- session_delete_test.go | 2 +- session_find.go | 7 ++-- session_get.go | 6 +++- session_pk_test.go | 26 +++++++++++++++ session_sum.go | 28 +++++++++++----- session_update.go | 26 ++++++++++++--- statement.go | 73 ++++++++++++++++++++++++++++++------------ 10 files changed, 144 insertions(+), 43 deletions(-) diff --git a/rows.go b/rows.go index 47bc322f..5aa4ffc3 100644 --- a/rows.go +++ b/rows.go @@ -33,8 +33,9 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var sqlStr string var args []interface{} + var err error - if err := rows.session.Statement.setRefValue(rValue(bean)); err != nil { + if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil { return nil, err } @@ -43,7 +44,10 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { } if rows.session.Statement.RawSQL == "" { - sqlStr, args = rows.session.Statement.genGetSQL(bean) + sqlStr, args, err = rows.session.Statement.genGetSQL(bean) + if err != nil { + return nil, err + } } else { sqlStr = rows.session.Statement.RawSQL args = rows.session.Statement.RawParams @@ -54,7 +58,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { } rows.session.saveLastSQL(sqlStr, args...) - var err error if rows.session.prepareStmt { rows.stmt, err = rows.session.DB().Prepare(sqlStr) if err != nil { diff --git a/session.go b/session.go index bbe56adc..afcab3c9 100644 --- a/session.go +++ b/session.go @@ -48,6 +48,8 @@ type Session struct { //beforeSQLExec func(string, ...interface{}) lastSQL string lastSQLArgs []interface{} + + err error } // Clone copy all the session's content and return a new session @@ -620,7 +622,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i structInter := reflect.New(fieldValue.Type()) newsession := session.Engine.NewSession() defer newsession.Close() - has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) + has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface()) if err != nil { return nil, err } diff --git a/session_delete.go b/session_delete.go index 0c1e705e..044e9d61 100644 --- a/session_delete.go +++ b/session_delete.go @@ -98,8 +98,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) { processor.BeforeDelete() } - // -- - condSQL, condArgs, _ := session.Statement.genConds(bean) + condSQL, condArgs, err := session.Statement.genConds(bean) + if err != nil { + return 0, err + } if len(condSQL) == 0 && session.Statement.LimitN == 0 { return 0, ErrNeedDeletedCond } diff --git a/session_delete_test.go b/session_delete_test.go index 6522a139..27e61321 100644 --- a/session_delete_test.go +++ b/session_delete_test.go @@ -15,7 +15,7 @@ func TestDelete(t *testing.T) { assert.NoError(t, prepareEngine()) type UserinfoDelete struct { - Uid int64 + Uid int64 `xorm:"id pk not null autoincr"` IsMan bool } diff --git a/session_find.go b/session_find.go index 9ee37201..9b8b31ef 100644 --- a/session_find.go +++ b/session_find.go @@ -91,6 +91,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var sqlStr string var args []interface{} + var err error if session.Statement.RawSQL == "" { if len(session.Statement.TableName()) <= 0 { return ErrTableNotFound @@ -128,7 +129,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } args = append(session.Statement.joinArgs, condArgs...) - sqlStr = session.Statement.genSelectSQL(columnStr, condSQL) + sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL) + if err != nil { + return err + } // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { @@ -139,7 +143,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) args = session.Statement.RawParams } - var err error if session.canCache() { if cacher := session.Engine.getCacher2(table); cacher != nil && !session.Statement.IsDistinct && diff --git a/session_get.go b/session_get.go index c7c03d90..64203513 100644 --- a/session_get.go +++ b/session_get.go @@ -33,13 +33,17 @@ func (session *Session) Get(bean interface{}) (bool, error) { var sqlStr string var args []interface{} + var err error if session.Statement.RawSQL == "" { if len(session.Statement.TableName()) <= 0 { return false, ErrTableNotFound } session.Statement.Limit(1) - sqlStr, args = session.Statement.genGetSQL(bean) + sqlStr, args, err = session.Statement.genGetSQL(bean) + if err != nil { + return false, err + } } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams diff --git a/session_pk_test.go b/session_pk_test.go index 0dff9df7..0e17352c 100644 --- a/session_pk_test.go +++ b/session_pk_test.go @@ -1139,3 +1139,29 @@ func TestCompositePK(t *testing.T) { assert.EqualValues(t, "uid", pkCols[0].Name) assert.EqualValues(t, "tid", pkCols[1].Name) } + +func TestNoPKIdQueryUpdate(t *testing.T) { + type NoPKTable struct { + Username string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(NoPKTable)) + + cnt, err := testEngine.Insert(&NoPKTable{ + Username: "test", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var res NoPKTable + has, err := testEngine.ID("test").Get(&res) + assert.Error(t, err) + assert.False(t, has) + + cnt, err = testEngine.ID("test").Update(&NoPKTable{ + Username: "test1", + }) + assert.Error(t, err) + assert.EqualValues(t, 0, cnt) +} diff --git a/session_sum.go b/session_sum.go index 8b2d38c2..b0663708 100644 --- a/session_sum.go +++ b/session_sum.go @@ -16,11 +16,15 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { var sqlStr string var args []interface{} + var err error if session.Statement.RawSQL == "" { if len(bean) == 0 { return 0, ErrTableNotFound } - sqlStr, args = session.Statement.genCountSQL(bean[0]) + sqlStr, args, err = session.Statement.genCountSQL(bean[0]) + if err != nil { + return 0, err + } } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -28,7 +32,6 @@ func (session *Session) Count(bean ...interface{}) (int64, error) { session.queryPreprocess(&sqlStr, args...) - var err error var total int64 if session.IsAutoCommit { err = session.DB().QueryRow(sqlStr, args...).Scan(&total) @@ -52,8 +55,12 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error var sqlStr string var args []interface{} + var err error if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSQL(bean, columnName) + sqlStr, args, err = session.Statement.genSumSQL(bean, columnName) + if err != nil { + return 0, err + } } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -61,7 +68,6 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error session.queryPreprocess(&sqlStr, args...) - var err error var res float64 if session.IsAutoCommit { err = session.DB().QueryRow(sqlStr, args...).Scan(&res) @@ -84,8 +90,12 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64 var sqlStr string var args []interface{} + var err error if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) + sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...) + if err != nil { + return nil, err + } } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -93,7 +103,6 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64 session.queryPreprocess(&sqlStr, args...) - var err error var res = make([]float64, len(columnNames), len(columnNames)) if session.IsAutoCommit { err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) @@ -116,8 +125,12 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6 var sqlStr string var args []interface{} + var err error if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) + sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...) + if err != nil { + return nil, err + } } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -125,7 +138,6 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6 session.queryPreprocess(&sqlStr, args...) - var err error var res = make([]int64, len(columnNames), len(columnNames)) if session.IsAutoCommit { err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) diff --git a/session_update.go b/session_update.go index 7cb38c22..792fb574 100644 --- a/session_update.go +++ b/session_update.go @@ -236,7 +236,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr) } - session.Statement.processIDParam() + if err = session.Statement.processIDParam(); err != nil { + return 0, err + } var autoCond builder.Cond if !session.Statement.noAutoCondition && len(condiBean) > 0 { @@ -267,7 +269,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") } - condSQL, condArgs, _ = builder.ToSQL(cond) + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } + if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } @@ -285,7 +291,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) - condSQL, condArgs, _ = builder.ToSQL(cond) + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } @@ -293,7 +302,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) - condSQL, condArgs, _ = builder.ToSQL(cond) + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } + if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } @@ -304,7 +317,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...) - condSQL, condArgs, _ = builder.ToSQL(cond) + condSQL, condArgs, err = builder.ToSQL(cond) + if err != nil { + return 0, err + } if len(condSQL) > 0 { condSQL = "WHERE " + condSQL } diff --git a/statement.go b/statement.go index 58fa616b..80fad4f4 100644 --- a/statement.go +++ b/statement.go @@ -1118,12 +1118,14 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e statement.cond = statement.cond.And(autoCond) } - statement.processIDParam() + if err := statement.processIDParam(); err != nil { + return "", nil, err + } return builder.ToSQL(statement.cond) } -func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { +func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) { v := rValue(bean) isStruct := v.Kind() == reflect.Struct if isStruct { @@ -1158,19 +1160,31 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) var condSQL string var condArgs []interface{} + var err error if isStruct { - condSQL, condArgs, _ = statement.genConds(bean) + condSQL, condArgs, err = statement.genConds(bean) } else { - condSQL, condArgs, _ = builder.ToSQL(statement.cond) + condSQL, condArgs, err = builder.ToSQL(statement.cond) + } + if err != nil { + return "", nil, err } - return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(columnStr, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) { +func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) { statement.setRefValue(rValue(bean)) - condSQL, condArgs, _ := statement.genConds(bean) + condSQL, condArgs, err := statement.genConds(bean) + if err != nil { + return "", nil, err + } var selectSQL = statement.selectStr if len(selectSQL) <= 0 { @@ -1180,10 +1194,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{} selectSQL = "count(*)" } } - return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) { +func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { statement.setRefValue(rValue(bean)) var sumStrs = make([]string, 0, len(columns)) @@ -1195,12 +1214,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri } sumSelect := strings.Join(sumStrs, ", ") - condSQL, condArgs, _ := statement.genConds(bean) + condSQL, condArgs, err := statement.genConds(bean) + if err != nil { + return "", nil, err + } - return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...) + sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) + if err != nil { + return "", nil, err + } + + return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { +func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { var distinct string if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " @@ -1211,7 +1238,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { var top string var mssqlCondi string - statement.processIDParam() + if err := statement.processIDParam(); err != nil { + return "", err + } var buf bytes.Buffer if len(condSQL) > 0 { @@ -1314,19 +1343,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { return } -func (statement *Statement) processIDParam() { +func (statement *Statement) processIDParam() error { if statement.idParam == nil { - return + return nil + } + + if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) { + return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d", + len(statement.RefTable.PrimaryKeys), + len(*statement.idParam), + ) } for i, col := range statement.RefTable.PKColumns() { var colName = statement.colName(col, statement.TableName()) - if i < len(*(statement.idParam)) { - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) - } else { - statement.cond = statement.cond.And(builder.Eq{colName: ""}) - } + statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]}) } + return nil } func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {