diff --git a/Makefile b/Makefile index 2fd48527..2551cb52 100644 --- a/Makefile +++ b/Makefile @@ -187,25 +187,25 @@ test-oracle: test-godror .PNONY: test-oci8 test-oci8: go-check - $(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-oci8\#% test-oci8\#%: go-check - $(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror test-godror: go-check - $(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror\#% test-godror\#%: go-check - $(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic diff --git a/go.sum b/go.sum index 82d92c88..38567910 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,7 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.0.0 h1:X5PMW56eZitiTeO7tKzZxFCSpbFZJtkMMooicw2us9A= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/mattn/go-oci8 v0.0.4 h1:3h8d3VE8buPHoEcApdEoww7Gy3G0SWhwJ0UpniYxBJU= github.com/mattn/go-oci8 v0.0.4/go.mod h1:wjDx6Xm9q7dFtHJvIlrI99JytznLw5wQ4R+9mNXJwGI= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= diff --git a/oracle_test.go b/integrations/oracle_test.go similarity index 91% rename from oracle_test.go rename to integrations/oracle_test.go index ccebd965..56cd3b69 100644 --- a/oracle_test.go +++ b/integrations/oracle_test.go @@ -4,7 +4,7 @@ // +build oracle -package xorm +package integrations import ( _ "github.com/mattn/go-oci8" diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 79182a63..3ebdc07d 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -25,7 +25,7 @@ func (statement *Statement) writeInsertOutput(buf *strings.Builder, table *schem } // GenInsertSQL generates insert beans SQL -func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (string, []interface{}, error) { +func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) (*builder.BytesWriter, error) { var ( buf = builder.NewWriter() exprs = statement.ExprColumns @@ -34,29 +34,29 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) ) if _, err := buf.WriteString("INSERT INTO "); err != nil { - return "", nil, err + return nil, err } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { - return "", nil, err + return nil, err } if len(colNames) <= 0 { if statement.dialect.URI().DBType == schemas.MYSQL { if _, err := buf.WriteString(" VALUES ()"); err != nil { - return "", nil, err + return nil, err } } else { if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(" DEFAULT VALUES"); err != nil { - return "", nil, err + return nil, err } } } else { if _, err := buf.WriteString(" ("); err != nil { - return "", nil, err + return nil, err } if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE { @@ -64,94 +64,95 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) } if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(")"); err != nil { - return "", nil, err + return nil, err } if err := statement.writeInsertOutput(buf.Builder, table); err != nil { - return "", nil, err + return nil, err } if statement.Conds().IsValid() { if _, err := buf.WriteString(" SELECT "); err != nil { - return "", nil, err + return nil, err } if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err + return nil, err } if len(exprs.Args) > 0 { if _, err := buf.WriteString(","); err != nil { - return "", nil, err + return nil, err } } if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(" FROM "); err != nil { - return "", nil, err + return nil, err } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(" WHERE "); err != nil { - return "", nil, err + return nil, err } if err := statement.Conds().WriteTo(buf); err != nil { - return "", nil, err + return nil, err } } else { if _, err := buf.WriteString(" VALUES ("); err != nil { - return "", nil, err + return nil, err } if err := statement.WriteArgs(buf, args); err != nil { - return "", nil, err + return nil, err } // Insert tablename (id) Values(seq_tablename.nextval) if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE { if _, err := buf.WriteString(","); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString("seq_" + tableName + ".nextval"); err != nil { - return "", nil, err + return nil, err } } if len(exprs.Args) > 0 { if _, err := buf.WriteString(","); err != nil { - return "", nil, err + return nil, err } } if err := exprs.WriteArgs(buf); err != nil { - return "", nil, err + return nil, err } if _, err := buf.WriteString(")"); err != nil { - return "", nil, err + return nil, err } } } - if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { + if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || + statement.dialect.URI().DBType == schemas.ORACLE) { if _, err := buf.WriteString(" RETURNING "); err != nil { - return "", nil, err + return nil, err } if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { - return "", nil, err + return nil, err } } - return buf.String(), buf.Args(), nil + return buf, nil } // GenInsertMapSQL generates insert map SQL diff --git a/internal/statements/statement.go b/internal/statements/statement.go index ed7bdaeb..f64fd085 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -90,6 +90,10 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ return statement } +func (statement *Statement) Dialect() dialects.Dialect { + return statement.dialect +} + func (statement *Statement) SetTableName(tableName string) { statement.tableName = tableName } diff --git a/session_insert.go b/session_insert.go index c1c8c608..77e36bed 100644 --- a/session_insert.go +++ b/session_insert.go @@ -309,7 +309,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, err } - sqlStr, args, err := session.statement.GenInsertSQL(colNames, args) + buf, err := session.statement.GenInsertSQL(colNames, args) if err != nil { return 0, err } @@ -342,51 +342,22 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { cleanupProcessorsClosures(&session.afterClosures) // cleanup after used } + var dialect = session.statement.Dialect() + // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { - _, err := session.exec(sqlStr, args...) - if err != nil { - return 0, err - } - - defer handleAfterInsertProcessorFunc(bean) - - session.cacheInsert(tableName) - - if table.Version != "" && session.statement.CheckVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) - } - } - + if len(table.AutoIncrement) > 0 && (dialect.URI().DBType == schemas.POSTGRES || + dialect.URI().DBType == schemas.MSSQL || + dialect.URI().DBType == schemas.ORACLE) { var id int64 - err = session.queryRow(fmt.Sprintf("select seq_%s.currval from dual", tableName)).Scan(&id) - if err != nil { - return 1, err - } - if id == 0 { - return 1, errors.New("insert no error but not returned id") + if dialect.URI().DBType == schemas.ORACLE { + if _, err := buf.WriteString(" INTO :var_name"); err != nil { + return 0, err + } + buf.Append(&id) } - aiValue, err := table.AutoIncrColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } - - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return 1, nil - } - - aiValue.Set(int64ToIntValue(id, aiValue.Type())) - - return 1, nil - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || - session.engine.dialect.URI().DBType == schemas.MSSQL) { - res, err := session.queryBytes(sqlStr, args...) + res, err := session.queryBytes(buf.String(), buf.Args()...) if err != nil { return 0, err @@ -404,14 +375,16 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if len(res) < 1 { - return 0, errors.New("insert successfully but not returned id") - } + if dialect.URI().DBType != schemas.ORACLE { + if len(res) < 1 { + return 0, errors.New("insert successfully but not returned id") + } - idByte := res[0][table.AutoIncrement] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil || id <= 0 { - return 1, err + idByte := res[0][table.AutoIncrement] + id, err = strconv.ParseInt(string(idByte), 10, 64) + if err != nil || id <= 0 { + return 1, err + } } aiValue, err := table.AutoIncrColumn().ValueOf(bean) @@ -428,7 +401,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - res, err := session.exec(sqlStr, args...) + res, err := session.exec(buf.String(), buf.Args()...) if err != nil { return 0, err }