From cf53cb80a06f129c2e69d1bed706afd190c14ade Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 27 Mar 2020 17:28:31 +0800 Subject: [PATCH] Fix insert returnning id --- Makefile | 8 +-- oracle_test.go => integrations/oracle_test.go | 2 +- internal/statements/insert.go | 51 ++++++++++--------- internal/statements/statement.go | 4 ++ session_insert.go | 4 +- 5 files changed, 37 insertions(+), 32 deletions(-) rename oracle_test.go => integrations/oracle_test.go (91%) diff --git a/Makefile b/Makefile index 62e91439..41f7bdd4 100644 --- a/Makefile +++ b/Makefile @@ -201,25 +201,25 @@ test-oracle: test-godror .PNONY: test-oci8 test-oci8: go-check - $(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-oci8\#% test-oci8\#%: go-check - $(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror test-godror: go-check - $(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror\#% test-godror\#%: go-check - $(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic diff --git a/oracle_test.go b/integrations/oracle_test.go similarity index 91% rename from oracle_test.go rename to integrations/oracle_test.go index ccebd965..56cd3b69 100644 --- a/oracle_test.go +++ b/integrations/oracle_test.go @@ -4,7 +4,7 @@ // +build oracle -package xorm +package integrations import ( _ "github.com/mattn/go-oci8" diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 91a33319..2c3fda60 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -27,7 +27,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem } // GenInsertSQL generates insert beans SQL -func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { +func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (*builder.BytesWriter, error) { var ( buf = builder.NewWriter() exprs = statement.ExprColumns @@ -36,11 +36,11 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) ) if _, err := buf.WriteString("INSERT INTO "); err != nil { - return "", nil, err + return nil, err } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { - return "", nil, err + return nil, err } var hasInsertColumns = len(colNames) > 0 @@ -58,19 +58,19 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) statement.dialect.URI().DBType != schemas.DAMENG { if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { - return "", nil, err + return nil, err } } else { if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil { - return "", nil, err + return nil, err } } } else { if _, err := buf.WriteString(" ("); err != nil { - return "", nil, err + return nil, err } if needSeq { @@ -82,19 +82,19 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if _, err := buf.WriteString(")"); err != nil { - return "", nil, err + return nil, err } if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err + return nil, err } if statement.Conds().IsValid() { if _, err := buf.WriteString(" SELECT "); err != nil { - return "", nil, err + return nil, err } if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err + return nil, err } if needSeq { @@ -109,7 +109,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { - return "", nil, err + return nil, err } if err := exprs.WriteArgs(buf); err != nil { return "", nil, err @@ -117,27 +117,27 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if _, err := buf.WriteString(" FROM "); err != nil { - return "", nil, err + return nil, err } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(" WHERE "); err != nil { - return "", nil, err + return nil, err } if err := statement.Conds().WriteTo(buf); err != nil { - return "", nil, err + return nil, err } } else { if _, err := buf.WriteString(" VALUES ("); err != nil { - return "", nil, err + return nil, err } if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err + return nil, err } // Insert tablename (id) Values(seq_tablename.nextval) @@ -154,30 +154,31 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) if len(exprs) > 0 { if _, err := buf.WriteString(","); err != nil { - return "", nil, err + return nil, err } } if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(")"); err != nil { - return "", nil, err + return nil, err } } } - if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || + statement.dialect.URI().DBType == schemas.ORACLE) { if _, err := buf.WriteString(" RETURNING "); err != nil { - return "", nil, err + return nil, err } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { - return "", nil, err + return nil, err } } - return buf.String(), buf.Args(), nil + return buf, nil } // GenInsertMapSQL generates insert map SQL diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 80451f50..b4566a42 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -91,6 +91,10 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ return statement } +func (statement *Statement) Dialect() dialects.Dialect { + return statement.dialect +} + // SetTableName set table name func (statement *Statement) SetTableName(tableName string) { statement.tableName = tableName diff --git a/session_insert.go b/session_insert.go index cb95dd7f..4540af72 100644 --- a/session_insert.go +++ b/session_insert.go @@ -280,7 +280,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { return 0, err } - sqlStr, args, err := session.statement.GenInsertSQL(colNames, args) + buf, err := session.statement.GenInsertSQL(colNames, args) if err != nil { return 0, err } @@ -384,7 +384,7 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { return 1, convert.AssignValue(*aiValue, id) } - res, err := session.exec(sqlStr, args...) + res, err := session.exec(buf.String(), buf.Args()...) if err != nil { return 0, err }