From e502385b12496b3ac7940d2a74891e6f2ad1c485 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 12 Aug 2021 16:19:11 +0800 Subject: [PATCH] Fix tests --- dialects/db2.go | 6 +-- integrations/session_cols_test.go | 6 +-- internal/utils/name.go | 6 +++ internal/utils/slice.go | 9 +++++ schemas/type.go | 1 + session_insert.go | 61 ++++++++++++++++++++++--------- 6 files changed, 64 insertions(+), 25 deletions(-) diff --git a/dialects/db2.go b/dialects/db2.go index 0a106d06..74f2d750 100644 --- a/dialects/db2.go +++ b/dialects/db2.go @@ -61,7 +61,7 @@ func (db *db2) SQLType(c *schemas.Column) string { res = schemas.BigInt case schemas.UnsignedInt: res = schemas.BigInt - case schemas.Bit: + case schemas.Bit, schemas.Bool, schemas.Boolean: res = schemas.Boolean return res case schemas.Binary, schemas.VarBinary: @@ -82,10 +82,6 @@ func (db *db2) SQLType(c *schemas.Column) string { res = t } - if strings.EqualFold(res, "bool") { - // for bool, we don't need length information - return res - } hasLen1 := (c.Length > 0) hasLen2 := (c.Length2 > 0) diff --git a/integrations/session_cols_test.go b/integrations/session_cols_test.go index b74c6f8a..1b91e0bb 100644 --- a/integrations/session_cols_test.go +++ b/integrations/session_cols_test.go @@ -52,11 +52,11 @@ func TestSetExpr(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 1, cnt) - tableName := testEngine.TableName(new(UserExprIssue), true) + tableName := testEngine.Quote(testEngine.TableName(new(UserExprIssue), true)) cnt, err = testEngine.SetExpr("issue_id", - builder.Select("id"). + builder.Select("`id`"). From(tableName). - Where(builder.Eq{"id": issue.Id})). + Where(builder.Eq{"`id`": issue.Id})). ID(1). Update(new(UserExpr)) assert.NoError(t, err) diff --git a/internal/utils/name.go b/internal/utils/name.go index 840dd9e8..aeef683d 100644 --- a/internal/utils/name.go +++ b/internal/utils/name.go @@ -6,9 +6,15 @@ package utils import ( "fmt" + "strings" ) // IndexName returns index name func IndexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } + +// SeqName returns sequence name for some table +func SeqName(tableName string) string { + return "SEQ_" + strings.ToUpper(tableName) +} diff --git a/internal/utils/slice.go b/internal/utils/slice.go index 89685706..b568f6f5 100644 --- a/internal/utils/slice.go +++ b/internal/utils/slice.go @@ -20,3 +20,12 @@ func SliceEq(left, right []string) bool { } return true } + +func IndexSlice(s []string, c string) int { + for i, ss := range s { + if c == ss { + return i + } + } + return -1 +} diff --git a/schemas/type.go b/schemas/type.go index cf730134..c66824a6 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -22,6 +22,7 @@ const ( MYSQL DBType = "mysql" MSSQL DBType = "mssql" ORACLE DBType = "oracle" + DB2 DBType = "db2" ) // SQLType represents SQL types diff --git a/session_insert.go b/session_insert.go index a8f365c7..d4555730 100644 --- a/session_insert.go +++ b/session_insert.go @@ -307,16 +307,53 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { // if there is auto increment column and driver don't support return it if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { - var sql = sqlStr - if session.engine.dialect.URI().DBType == schemas.ORACLE { - sql = "select seq_atable.currval from dual" + var sql string + var newArgs []interface{} + var needCommit bool + var id int64 + if session.engine.dialect.URI().DBType == schemas.DB2 || session.engine.dialect.URI().DBType == schemas.ORACLE { + if session.isAutoCommit { // if it's not in transaction + if err := session.Begin(); err != nil { + return 0, err + } + needCommit = true + } + _, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } + i := utils.IndexSlice(colNames, table.AutoIncrement) + if i > -1 { + id, err = convert.AsInt64(args[i]) + if err != nil { + return 0, err + } + } else { + if session.engine.dialect.URI().DBType == schemas.ORACLE { + sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName)) + } else if session.engine.dialect.URI().DBType == schemas.DB2 { + sql = "select IDENTITY_VAL_LOCAL() as id FROM sysibm.sysdummy1" + } + } + } else { + sql = sqlStr + newArgs = args } - rows, err := session.queryRows(sql, args...) - if err != nil { - return 0, err + if id == 0 { + err := session.queryRow(sql, newArgs...).Scan(&id) + if err != nil { + return 0, err + } + if needCommit { + if err := session.Commit(); err != nil { + return 0, err + } + } + if id == 0 { + return 0, errors.New("insert successfully but not returned id") + } } - defer rows.Close() defer handleAfterInsertProcessorFunc(bean) @@ -331,16 +368,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { } } - var id int64 - if !rows.Next() { - if rows.Err() != nil { - return 0, rows.Err() - } - return 0, errors.New("insert successfully but not returned id") - } - if err := rows.Scan(&id); err != nil { - return 1, err - } aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { session.engine.logger.Errorf("%v", err)