diff --git a/session_insert.go b/session_insert.go index d4555730..07a63648 100644 --- a/session_insert.go +++ b/session_insert.go @@ -311,8 +311,15 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { var newArgs []interface{} var needCommit bool var id int64 + var i = utils.IndexSlice(colNames, table.AutoIncrement) + if i > -1 { + id, err = convert.AsInt64(args[i]) + if err != nil { + return 0, err + } + } if session.engine.dialect.URI().DBType == schemas.DB2 || session.engine.dialect.URI().DBType == schemas.ORACLE { - if session.isAutoCommit { // if it's not in transaction + if id == 0 && session.isAutoCommit { // if it's not in transaction if err := session.Begin(); err != nil { return 0, err } @@ -320,20 +327,16 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { } _, err := session.exec(sqlStr, args...) if err != nil { + if needCommit { + session.Rollback() + } return 0, err } - i := utils.IndexSlice(colNames, table.AutoIncrement) - if i > -1 { - id, err = convert.AsInt64(args[i]) - if err != nil { - return 0, err - } - } else { - if session.engine.dialect.URI().DBType == schemas.ORACLE { - sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) - } else if session.engine.dialect.URI().DBType == schemas.DB2 { - sql = "select IDENTITY_VAL_LOCAL() as id FROM sysibm.sysdummy1" - } + + if session.engine.dialect.URI().DBType == schemas.ORACLE { + sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) + } else if session.engine.dialect.URI().DBType == schemas.DB2 { + sql = "select IDENTITY_VAL_LOCAL() as id FROM sysibm.sysdummy1" } } else { sql = sqlStr @@ -343,6 +346,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { if id == 0 { err := session.queryRow(sql, newArgs...).Scan(&id) if err != nil { + if needCommit { + session.Rollback() + } return 0, err } if needCommit {