diff --git a/error.go b/error.go index a223fc4a..a67527ac 100644 --- a/error.go +++ b/error.go @@ -26,6 +26,8 @@ var ( ErrNotImplemented = errors.New("Not implemented") // ErrConditionType condition type unsupported ErrConditionType = errors.New("Unsupported condition type") + // ErrUnSupportedSQLType parameter of SQL is not supported + ErrUnSupportedSQLType = errors.New("unsupported sql type") ) // ErrFieldIsNotExist columns does not exist diff --git a/session_delete.go b/session_delete.go index dcce543a..26782f69 100644 --- a/session_delete.go +++ b/session_delete.go @@ -79,6 +79,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } + if session.statement.lastError != nil { + return 0, session.statement.lastError + } + if err := session.statement.setRefBean(bean); err != nil { return 0, err } diff --git a/session_exist.go b/session_exist.go index 74a660e8..df205d33 100644 --- a/session_exist.go +++ b/session_exist.go @@ -19,6 +19,10 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { defer session.Close() } + if session.statement.lastError != nil { + return false, session.statement.lastError + } + var sqlStr string var args []interface{} var err error diff --git a/session_find.go b/session_find.go index a5b4f793..48ee3209 100644 --- a/session_find.go +++ b/session_find.go @@ -63,6 +63,10 @@ func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...inte } func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { + if session.statement.lastError != nil { + return session.statement.lastError + } + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { return errors.New("needs a pointer to a slice or a map") diff --git a/session_get.go b/session_get.go index 1cea31c5..5ecf2f37 100644 --- a/session_get.go +++ b/session_get.go @@ -24,6 +24,10 @@ func (session *Session) Get(bean interface{}) (bool, error) { } func (session *Session) get(bean interface{}) (bool, error) { + if session.statement.lastError != nil { + return false, session.statement.lastError + } + beanValue := reflect.ValueOf(bean) if beanValue.Kind() != reflect.Ptr { return false, errors.New("needs a pointer to a value") diff --git a/session_iterate.go b/session_iterate.go index 071fce49..ca996c28 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -23,6 +23,10 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { defer session.Close() } + if session.statement.lastError != nil { + return session.statement.lastError + } + if session.statement.bufferSize > 0 { return session.bufferIterate(bean, fun) } diff --git a/session_update.go b/session_update.go index 37b34ff3..6bd16aaf 100644 --- a/session_update.go +++ b/session_update.go @@ -147,6 +147,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } + if session.statement.lastError != nil { + return 0, session.statement.lastError + } + v := rValue(bean) t := v.Type() diff --git a/session_update_test.go b/session_update_test.go index 2a7005ee..53fdd270 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -1331,3 +1331,21 @@ func TestUpdateCondiBean(t *testing.T) { assert.NoError(t, err) assert.True(t, has) } + +func TestWhereCondErrorWhenUpdate(t *testing.T) { + type AuthRequestError struct { + ChallengeToken string + RequestToken string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(AuthRequestError)) + + _, err := testEngine.Cols("challenge_token", "request_token", "challenge_agent", "status"). + Where(&AuthRequestError{ChallengeToken: "1"}). + Update(&AuthRequestError{ + ChallengeToken: "2", + }) + assert.Error(t, err) + assert.EqualValues(t, ErrConditionType, err) +} diff --git a/statement.go b/statement.go index a7f7010a..ec43ce56 100644 --- a/statement.go +++ b/statement.go @@ -60,6 +60,7 @@ type Statement struct { cond builder.Cond bufferSize int context ContextCache + lastError error } // Init reset all the statement's fields @@ -101,6 +102,7 @@ func (statement *Statement) Init() { statement.cond = builder.NewCond() statement.bufferSize = 0 statement.context = nil + statement.lastError = nil } // NoAutoCondition if you do not want convert bean's field as query condition, then use this function @@ -125,13 +127,13 @@ func (statement *Statement) SQL(query interface{}, args ...interface{}) *Stateme var err error statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() if err != nil { - statement.Engine.logger.Error(err) + statement.lastError = err } case string: statement.RawSQL = query.(string) statement.RawParams = args default: - statement.Engine.logger.Error("unsupported sql type") + statement.lastError = ErrUnSupportedSQLType } return statement @@ -160,7 +162,7 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme } } default: - // TODO: not support condition type + statement.lastError = ErrConditionType } return statement