From fce6fbd4d5b83b4267a7b05893a09eecd68dfdf9 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 8 Mar 2020 01:06:27 +0800 Subject: [PATCH] fix bugs --- Makefile | 8 ++-- core/interface.go | 1 + dialects/dialect.go | 4 +- dialects/mssql.go | 6 +-- dialects/oracle.go | 76 ++++++++++++++++++++++++++++++------- integrations/engine_test.go | 3 +- integrations/tests.go | 2 +- session_insert.go | 5 ++- session_schema.go | 29 +++++++++++++- 9 files changed, 104 insertions(+), 30 deletions(-) diff --git a/Makefile b/Makefile index 77f0c72b..62e91439 100644 --- a/Makefile +++ b/Makefile @@ -201,25 +201,25 @@ test-oracle: test-godror .PNONY: test-oci8 test-oci8: go-check - $(GO) test -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-oci8\#% test-oci8\#%: go-check - $(GO) test -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror test-godror: go-check - $(GO) test -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic .PHONY: test-godror\#% test-godror\#%: go-check - $(GO) test -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ + $(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic diff --git a/core/interface.go b/core/interface.go index a5c8e4e2..b2746ae0 100644 --- a/core/interface.go +++ b/core/interface.go @@ -7,6 +7,7 @@ import ( // Queryer represents an interface to query a SQL to get data from database type Queryer interface { + QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) } diff --git a/dialects/dialect.go b/dialects/dialect.go index 460ab56a..d06e4f28 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -169,9 +169,9 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) { } // DropTableSQL returns drop table SQL -func (db *Base) DropTableSQL(tableName string) (string, bool) { +func (db *Base) DropTableSQL(tableName, autoincrCol string) (string, bool) { quote := db.dialect.Quoter().Quote - return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true + return []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))}, true } // HasRecords returns true if the SQL has records returned diff --git a/dialects/mssql.go b/dialects/mssql.go index cd19afb9..1a8206ff 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -421,10 +421,10 @@ func (db *mssql) AutoIncrStr() string { return "IDENTITY" } -func (db *mssql) DropTableSQL(tableName string) (string, bool) { - return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ +func (db *mssql) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { + return []string{fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ - "DROP TABLE \"%s\"", tableName, tableName), true + "DROP TABLE \"%s\"", tableName, tableName)}, true } func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string { diff --git a/dialects/oracle.go b/dialects/oracle.go index 42c5b290..6940199b 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -553,9 +553,9 @@ func (db *oracle) SQLType(c *schemas.Column) string { case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: return schemas.Blob case schemas.Time, schemas.DateTime, schemas.TimeStamp: - res = schemas.TimeStamp + res = schemas.Date case schemas.TimeStampz: - res = "TIMESTAMP WITH TIME ZONE" + res = "TIMESTAMP" case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: res = "NUMBER" case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: @@ -601,8 +601,14 @@ func (db *oracle) IsReserved(name string) bool { return ok } -func (db *oracle) DropTableSQL(tableName string) (string, bool) { - return fmt.Sprintf("DROP TABLE `%s`", tableName), false +func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) { + var sqls = []string{ + fmt.Sprintf("DROP TABLE `%s`", tableName), + } + if autoincrCol != "" { + sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName))) + } + return sqls, false } func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { @@ -672,12 +678,32 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table return db.HasRecords(queryer, ctx, query, args...) } -func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { - args := []interface{}{tableName} - s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + - "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" +func seqName(tableName string) string { + return "SEQ_" + strings.ToUpper(tableName) +} - rows, err := queryer.QueryContext(ctx, s, args...) +func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { + //s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + + // "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" + + s := `select column_name from user_cons_columns + where constraint_name = (select constraint_name from user_constraints + where table_name = :1 and constraint_type ='P')` + var pkName string + err := queryer.QueryRowContext(ctx, s, tableName).Scan(&pkName) + if err != nil { + return nil, nil, err + } + + s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH, + USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE, + user_col_comments.comments + FROM USER_TAB_COLS + LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME + AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME + WHERE USER_TAB_COLS.table_name = :1` + + rows, err := queryer.QueryContext(ctx, s, tableName) if err != nil { return nil, nil, err } @@ -689,11 +715,11 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam col := new(schemas.Column) col.Indexes = make(map[string]int) - var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string + var colName, colDefault, nullable, dataType, dataPrecision, dataScale, comment *string var dataLen int err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, - &dataScale, &nullable) + &dataScale, &nullable, &comment) if err != nil { return nil, nil, err } @@ -710,10 +736,28 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam col.Nullable = false } - var ignore bool + if comment != nil { + col.Comment = *comment + } + if pkName != "" && pkName == col.Name { + col.IsPrimaryKey = true - var dt string - var len1, len2 int + has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName)) + if err != nil { + return nil, nil, err + } + if has { + col.IsAutoIncrement = true + } + + fmt.Println("-----", pkName, col.Name, col.IsPrimaryKey) + } + + var ( + ignore bool + dt string + len1, len2 int + ) dts := strings.Split(*dataType, "(") dt = dts[0] if len(dts) > 1 { @@ -769,6 +813,10 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam return nil, nil, rows.Err() } + /*select * + from user_tab_comments + where Table_Name='用户表' */ + return colSeq, cols, nil } diff --git a/integrations/engine_test.go b/integrations/engine_test.go index 06ab4988..047cba83 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -63,8 +63,7 @@ func TestAutoTransaction(t *testing.T) { engine.Transaction(func(session *xorm.Session) (interface{}, error) { _, err := session.Insert(TestTx{Msg: "hi"}) assert.NoError(t, err) - - return nil, nil + return nil, err }) has, err := engine.Exist(&TestTx{Msg: "hi"}) diff --git a/integrations/tests.go b/integrations/tests.go index 8b14b0f4..9f01e9ae 100644 --- a/integrations/tests.go +++ b/integrations/tests.go @@ -162,7 +162,7 @@ func createEngine(dbType, connStr string) error { if err != nil { return err } - var tableNames = make([]interface{}, 0, len(tables)) + var tableNames []interface{} for _, table := range tables { tableNames = append(tableNames, table.Name) } diff --git a/session_insert.go b/session_insert.go index 43a4118b..a1bb7007 100644 --- a/session_insert.go +++ b/session_insert.go @@ -257,7 +257,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { if err := session.statement.SetRefBean(bean); err != nil { return 0, err } - if len(session.statement.TableName()) <= 0 { + + var tableName = session.statement.TableName() + if tableName == "" { return 0, ErrTableNotFound } @@ -271,7 +273,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) { processor.BeforeInsert() } - var tableName = session.statement.TableName() table := session.statement.RefTable colNames, args, err := session.genInsertColumns(bean) diff --git a/session_schema.go b/session_schema.go index e9ed9ec5..bbb5d0be 100644 --- a/session_schema.go +++ b/session_schema.go @@ -148,8 +148,33 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { } func (session *Session) dropTable(beanOrTableName interface{}) error { - tableName := session.engine.TableName(beanOrTableName) - sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) + var tableName, autoIncrementCol string + switch beanOrTableName.(type) { + case *schemas.Table: + table := beanOrTableName.(*schemas.Table) + tableName = table.Name + if table.AutoIncrColumn() != nil { + autoIncrementCol = table.AutoIncrColumn().Name + } + case string: + tableName = beanOrTableName.(string) + default: + v := utils.ReflectValue(beanOrTableName) + table, err := session.engine.tagParser.ParseWithCache(v) + if err != nil { + return err + } + if session.statement.AltTableName != "" { + tableName = session.statement.AltTableName + } else { + tableName = table.Name + } + if table.AutoIncrColumn() != nil { + autoIncrementCol = table.AutoIncrColumn().Name + } + } + + sqlStrs, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol) if !checkIfExist { exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil {