From d7250e866b400613fd481d226b8198b4cb05322b Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 4 Mar 2020 13:38:05 +0800 Subject: [PATCH] fix oracle --- dialects/oracle.go | 4 ++-- integrations/cache_test.go | 8 ++++---- session_insert.go | 6 +++++- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dialects/oracle.go b/dialects/oracle.go index ad7b0095..3459146a 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -725,8 +725,8 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam } func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { - args := []interface{}{} - s := "SELECT table_name FROM user_tables" + s := "SELECT table_name FROM user_tables WHERE TABLESPACE_NAME = :1 AND table_name NOT LIKE :2" + args := []interface{}{strings.ToUpper(db.uri.User), "%$%"} rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { diff --git a/integrations/cache_test.go b/integrations/cache_test.go index 44e817b1..08f2615b 100644 --- a/integrations/cache_test.go +++ b/integrations/cache_test.go @@ -62,7 +62,7 @@ func TestCacheFind(t *testing.T) { } boxes = make([]MailBox, 0, 2) - assert.NoError(t, testEngine.Alias("a").Where("a.id > -1").Asc("a.id").Find(&boxes)) + assert.NoError(t, testEngine.Alias("a").Where("`a`.`id` > -1").Asc("`a`.`id`").Find(&boxes)) assert.EqualValues(t, 2, len(boxes)) for i, box := range boxes { assert.Equal(t, inserts[i].Id, box.Id) @@ -77,7 +77,7 @@ func TestCacheFind(t *testing.T) { } boxes2 := make([]MailBox4, 0, 2) - assert.NoError(t, testEngine.Table("mail_box").Where("mail_box.id > -1").Asc("mail_box.id").Find(&boxes2)) + assert.NoError(t, testEngine.Table("mail_box").Where("`mail_box`.`id` > -1").Asc("mail_box.id").Find(&boxes2)) assert.EqualValues(t, 2, len(boxes2)) for i, box := range boxes2 { assert.Equal(t, inserts[i].Id, box.Id) @@ -164,14 +164,14 @@ func TestCacheGet(t *testing.T) { assert.NoError(t, err) var box1 MailBox3 - has, err := testEngine.Where("id = ?", inserts[0].Id).Get(&box1) + has, err := testEngine.Where("`id` = ?", inserts[0].Id).Get(&box1) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box1.Username) assert.EqualValues(t, "pass1", box1.Password) var box2 MailBox3 - has, err = testEngine.Where("id = ?", inserts[0].Id).Get(&box2) + has, err = testEngine.Where("`id` = ?", inserts[0].Id).Get(&box2) assert.NoError(t, err) assert.True(t, has) assert.EqualValues(t, "user1", box2.Username) diff --git a/session_insert.go b/session_insert.go index 5f968151..3ec4e93f 100644 --- a/session_insert.go +++ b/session_insert.go @@ -337,7 +337,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { - res, err := session.queryBytes("select seq_atable.currval from dual", args...) + _, err := session.exec(sqlStr, args...) if err != nil { return 0, err } @@ -355,6 +355,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } + res, err := session.queryBytes("select seq_atable.currval from dual") + if err != nil { + return 0, err + } if len(res) < 1 { return 0, errors.New("insert no error but not returned id") }