diff --git a/Makefile b/Makefile index 6360b1e8..2fd48527 100644 --- a/Makefile +++ b/Makefile @@ -187,25 +187,25 @@ test-oracle: test-godror .PNONY: test-oci8 test-oci8: go-check - $(GO) test -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-oci8\#% test-oci8\#%: go-check - $(GO) test -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror test-godror: go-check - $(GO) test -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror\#% test-godror\#%: go-check - $(GO) test -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic diff --git a/core/interface.go b/core/interface.go index a5c8e4e2..b2746ae0 100644 --- a/core/interface.go +++ b/core/interface.go @@ -7,6 +7,7 @@ import ( // Queryer represents an interface to query a SQL to get data from database type Queryer interface { + QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) } diff --git a/dialects/dialect.go b/dialects/dialect.go index 53c8cf29..a1299bea 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -59,7 +59,7 @@ type Dialect interface { GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) - DropTableSQL(tableName string) (string, bool) + DropTableSQL(tableName, autoincrCol string) ([]string, bool) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) @@ -100,9 +100,9 @@ func (b *Base) FormatBytes(bs []byte) string { return fmt.Sprintf("0x%x", bs) } -func (db *Base) DropTableSQL(tableName string) (string, bool) { +func (db *Base) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { quote := db.dialect.Quoter().Quote - return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true + return []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))}, true } func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) { diff --git a/dialects/mssql.go b/dialects/mssql.go index 8ef924b8..3894e105 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -311,10 +311,10 @@ func (db *mssql) AutoIncrStr() string { return "IDENTITY" } -func (db *mssql) DropTableSQL(tableName string) (string, bool) { - return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ +func (db *mssql) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { + return []string{fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ - "DROP TABLE \"%s\"", tableName, tableName), true + "DROP TABLE \"%s\"", tableName, tableName)}, true } func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { diff --git a/dialects/oracle.go b/dialects/oracle.go index d29b1535..3e8b8630 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -523,9 +523,9 @@ func (db *oracle) SQLType(c *schemas.Column) string { case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: return schemas.Blob case schemas.Time, schemas.DateTime, schemas.TimeStamp: - res = schemas.TimeStamp + res = schemas.Date case schemas.TimeStampz: - res = "TIMESTAMP WITH TIME ZONE" + res = "TIMESTAMP" case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: res = "NUMBER" case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: @@ -556,8 +556,14 @@ func (db *oracle) IsReserved(name string) bool { return ok } -func (db *oracle) DropTableSQL(tableName string) (string, bool) { - return fmt.Sprintf("DROP TABLE `%s`", tableName), false +func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { + var sqls = []string{ + fmt.Sprintf("DROP TABLE `%s`", tableName), + } + if autoincrCol != "" { + sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName))) + } + return sqls, false } func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { @@ -589,8 +595,19 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri sql += " ), " } - sql = sql[:len(sql)-2] + ")" - return []string{sql}, false + var sqls = []string{sql[:len(sql)-2] + ")"} + + if table.AutoIncrColumn() != nil { + var sql2 = fmt.Sprintf(`CREATE sequence %s + minvalue 1 + nomaxvalue + start with 1 + increment by 1 + nocycle + nocache`, seqName(tableName)) + sqls = append(sqls, sql2) + } + return sqls, false } func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { @@ -627,12 +644,32 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table return db.HasRecords(queryer, ctx, query, args...) } -func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { - args := []interface{}{tableName} - s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + - "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" +func seqName(tableName string) string { + return "SEQ_" + strings.ToUpper(tableName) +} - rows, err := queryer.QueryContext(ctx, s, args...) +func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { + //s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + + // "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" + + s := `select column_name from user_cons_columns + where constraint_name = (select constraint_name from user_constraints + where table_name = :1 and constraint_type ='P')` + var pkName string + err := queryer.QueryRowContext(ctx, s, tableName).Scan(&pkName) + if err != nil { + return nil, nil, err + } + + s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH, + USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE, + user_col_comments.comments + FROM USER_TAB_COLS + LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME + AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME + WHERE USER_TAB_COLS.table_name = :1` + + rows, err := queryer.QueryContext(ctx, s, tableName) if err != nil { return nil, nil, err } @@ -644,11 +681,11 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam col := new(schemas.Column) col.Indexes = make(map[string]int) - var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string + var colName, colDefault, nullable, dataType, dataPrecision, dataScale, comment *string var dataLen int err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, - &dataScale, &nullable) + &dataScale, &nullable, &comment) if err != nil { return nil, nil, err } @@ -665,10 +702,28 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam col.Nullable = false } - var ignore bool + if comment != nil { + col.Comment = *comment + } + if pkName != "" && pkName == col.Name { + col.IsPrimaryKey = true - var dt string - var len1, len2 int + has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName)) + if err != nil { + return nil, nil, err + } + if has { + col.IsAutoIncrement = true + } + + fmt.Println("-----", pkName, col.Name, col.IsPrimaryKey) + } + + var ( + ignore bool + dt string + len1, len2 int + ) dts := strings.Split(*dataType, "(") dt = dts[0] if len(dts) > 1 { @@ -721,6 +776,10 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam colSeq = append(colSeq, col.Name) } + /*select * + from user_tab_comments + where Table_Name='用户表' */ + return colSeq, cols, nil } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 0b792b71..646ddcdf 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -60,8 +60,7 @@ func TestAutoTransaction(t *testing.T) { engine.Transaction(func(session *xorm.Session) (interface{}, error) { _, err := session.Insert(TestTx{Msg: "hi"}) assert.NoError(t, err) - - return nil, nil + return nil, err }) has, err := engine.Exist(&TestTx{Msg: "hi"}) diff --git a/integrations/tests.go b/integrations/tests.go index c8219935..a4105257 100644 --- a/integrations/tests.go +++ b/integrations/tests.go @@ -146,13 +146,12 @@ func createEngine(dbType, connStr string) error { if err != nil { return err } - var tableNames = make([]interface{}, 0, len(tables)) for _, table := range tables { - tableNames = append(tableNames, table.Name) - } - if err = testEngine.DropTables(tableNames...); err != nil { - return err + if err = testEngine.DropTables(table); err != nil { + return err + } } + return nil } diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 6cbbbeda..79182a63 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -59,6 +59,10 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE { + colNames = append(colNames, table.AutoIncrement) + } + if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { return "", nil, err } @@ -112,6 +116,16 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{}) return "", nil, err } + // Insert tablename (id) Values(seq_tablename.nextval) + if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE { + if _, err := buf.WriteString(","); err != nil { + return "", nil, err + } + if _, err := buf.WriteString("seq_" + tableName + ".nextval"); err != nil { + return "", nil, err + } + } + if len(exprs.Args) > 0 { if _, err := buf.WriteString(","); err != nil { return "", nil, err diff --git a/session_insert.go b/session_insert.go index 3ec4e93f..ffe47ef7 100644 --- a/session_insert.go +++ b/session_insert.go @@ -279,7 +279,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if err := session.statement.SetRefBean(bean); err != nil { return 0, err } - if len(session.statement.TableName()) <= 0 { + + var tableName = session.statement.TableName() + if tableName == "" { return 0, ErrTableNotFound } @@ -293,7 +295,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } - var tableName = session.statement.TableName() table := session.statement.RefTable colNames, args, err := session.genInsertColumns(bean) @@ -355,19 +356,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - res, err := session.queryBytes("select seq_atable.currval from dual") + var id int64 + err = session.queryRow(fmt.Sprintf("select %s.currval from dual", tableName)).Scan(&id) if err != nil { - return 0, err - } - if len(res) < 1 { - return 0, errors.New("insert no error but not returned id") - } - - idByte := res[0][table.AutoIncrement] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil || id <= 0 { return 1, err } + if id == 0 { + return 1, errors.New("insert no error but not returned id") + } aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { diff --git a/session_schema.go b/session_schema.go index 9ccf8abe..7f02e453 100644 --- a/session_schema.go +++ b/session_schema.go @@ -131,8 +131,33 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } func (session *Session) dropTable(beanOrTableName interface{}) error { - tableName := session.engine.TableName(beanOrTableName) - sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) + var tableName, autoIncrementCol string + switch beanOrTableName.(type) { + case *schemas.Table: + table := beanOrTableName.(*schemas.Table) + tableName = table.Name + if table.AutoIncrColumn() != nil { + autoIncrementCol = table.AutoIncrColumn().Name + } + case string: + tableName = beanOrTableName.(string) + default: + v := utils.ReflectValue(beanOrTableName) + table, err := session.engine.tagParser.ParseWithCache(v) + if err != nil { + return err + } + if session.statement.AltTableName != "" { + tableName = session.statement.AltTableName + } else { + tableName = table.Name + } + if table.AutoIncrColumn() != nil { + autoIncrementCol = table.AutoIncrColumn().Name + } + } + + sqlStrs, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol) if !checkIfExist { exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil { @@ -142,8 +167,11 @@ func (session *Session) dropTable(beanOrTableName interface{}) error { } if checkIfExist { - _, err := session.exec(sqlStr) - return err + for _, sqlStr := range sqlStrs { + if _, err := session.exec(sqlStr); err != nil { + return err + } + } } return nil }