Fix tests
This commit is contained in:
parent
3e45120160
commit
e502385b12
|
@ -61,7 +61,7 @@ func (db *db2) SQLType(c *schemas.Column) string {
|
||||||
res = schemas.BigInt
|
res = schemas.BigInt
|
||||||
case schemas.UnsignedInt:
|
case schemas.UnsignedInt:
|
||||||
res = schemas.BigInt
|
res = schemas.BigInt
|
||||||
case schemas.Bit:
|
case schemas.Bit, schemas.Bool, schemas.Boolean:
|
||||||
res = schemas.Boolean
|
res = schemas.Boolean
|
||||||
return res
|
return res
|
||||||
case schemas.Binary, schemas.VarBinary:
|
case schemas.Binary, schemas.VarBinary:
|
||||||
|
@ -82,10 +82,6 @@ func (db *db2) SQLType(c *schemas.Column) string {
|
||||||
res = t
|
res = t
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.EqualFold(res, "bool") {
|
|
||||||
// for bool, we don't need length information
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
hasLen1 := (c.Length > 0)
|
hasLen1 := (c.Length > 0)
|
||||||
hasLen2 := (c.Length2 > 0)
|
hasLen2 := (c.Length2 > 0)
|
||||||
|
|
||||||
|
|
|
@ -52,11 +52,11 @@ func TestSetExpr(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 1, cnt)
|
assert.EqualValues(t, 1, cnt)
|
||||||
|
|
||||||
tableName := testEngine.TableName(new(UserExprIssue), true)
|
tableName := testEngine.Quote(testEngine.TableName(new(UserExprIssue), true))
|
||||||
cnt, err = testEngine.SetExpr("issue_id",
|
cnt, err = testEngine.SetExpr("issue_id",
|
||||||
builder.Select("id").
|
builder.Select("`id`").
|
||||||
From(tableName).
|
From(tableName).
|
||||||
Where(builder.Eq{"id": issue.Id})).
|
Where(builder.Eq{"`id`": issue.Id})).
|
||||||
ID(1).
|
ID(1).
|
||||||
Update(new(UserExpr))
|
Update(new(UserExpr))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
|
@ -6,9 +6,15 @@ package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IndexName returns index name
|
// IndexName returns index name
|
||||||
func IndexName(tableName, idxName string) string {
|
func IndexName(tableName, idxName string) string {
|
||||||
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SeqName returns sequence name for some table
|
||||||
|
func SeqName(tableName string) string {
|
||||||
|
return "SEQ_" + strings.ToUpper(tableName)
|
||||||
|
}
|
||||||
|
|
|
@ -20,3 +20,12 @@ func SliceEq(left, right []string) bool {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IndexSlice(s []string, c string) int {
|
||||||
|
for i, ss := range s {
|
||||||
|
if c == ss {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@ const (
|
||||||
MYSQL DBType = "mysql"
|
MYSQL DBType = "mysql"
|
||||||
MSSQL DBType = "mssql"
|
MSSQL DBType = "mssql"
|
||||||
ORACLE DBType = "oracle"
|
ORACLE DBType = "oracle"
|
||||||
|
DB2 DBType = "db2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SQLType represents SQL types
|
// SQLType represents SQL types
|
||||||
|
|
|
@ -307,16 +307,53 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
|
||||||
|
|
||||||
// if there is auto increment column and driver don't support return it
|
// if there is auto increment column and driver don't support return it
|
||||||
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
|
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
|
||||||
var sql = sqlStr
|
var sql string
|
||||||
if session.engine.dialect.URI().DBType == schemas.ORACLE {
|
var newArgs []interface{}
|
||||||
sql = "select seq_atable.currval from dual"
|
var needCommit bool
|
||||||
|
var id int64
|
||||||
|
if session.engine.dialect.URI().DBType == schemas.DB2 || session.engine.dialect.URI().DBType == schemas.ORACLE {
|
||||||
|
if session.isAutoCommit { // if it's not in transaction
|
||||||
|
if err := session.Begin(); err != nil {
|
||||||
|
return 0, err
|
||||||
}
|
}
|
||||||
|
needCommit = true
|
||||||
rows, err := session.queryRows(sql, args...)
|
}
|
||||||
|
_, err := session.exec(sqlStr, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
i := utils.IndexSlice(colNames, table.AutoIncrement)
|
||||||
|
if i > -1 {
|
||||||
|
id, err = convert.AsInt64(args[i])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if session.engine.dialect.URI().DBType == schemas.ORACLE {
|
||||||
|
sql = fmt.Sprintf("select %s.currval from dual", utils.SeqName(tableName))
|
||||||
|
} else if session.engine.dialect.URI().DBType == schemas.DB2 {
|
||||||
|
sql = "select IDENTITY_VAL_LOCAL() as id FROM sysibm.sysdummy1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sql = sqlStr
|
||||||
|
newArgs = args
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == 0 {
|
||||||
|
err := session.queryRow(sql, newArgs...).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if needCommit {
|
||||||
|
if err := session.Commit(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if id == 0 {
|
||||||
|
return 0, errors.New("insert successfully but not returned id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
defer handleAfterInsertProcessorFunc(bean)
|
defer handleAfterInsertProcessorFunc(bean)
|
||||||
|
|
||||||
|
@ -331,16 +368,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var id int64
|
|
||||||
if !rows.Next() {
|
|
||||||
if rows.Err() != nil {
|
|
||||||
return 0, rows.Err()
|
|
||||||
}
|
|
||||||
return 0, errors.New("insert successfully but not returned id")
|
|
||||||
}
|
|
||||||
if err := rows.Scan(&id); err != nil {
|
|
||||||
return 1, err
|
|
||||||
}
|
|
||||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
session.engine.logger.Errorf("%v", err)
|
session.engine.logger.Errorf("%v", err)
|
||||||
|
|
Loading…
Reference in New Issue