This commit is contained in:
Lunny Xiao 2020-03-08 01:06:27 +08:00
parent e840c2e456
commit fce6fbd4d5
9 changed files with 104 additions and 30 deletions

View File

@ -201,25 +201,25 @@ test-oracle: test-godror
.PNONY: test-oci8
test-oci8: go-check
$(GO) test -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-oci8\#%
test-oci8\#%: go-check
$(GO) test -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-godror
test-godror: go-check
$(GO) test -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-godror\#%
test-godror\#%: go-check
$(GO) test -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
$(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

View File

@ -7,6 +7,7 @@ import (
// Queryer represents an interface to query a SQL to get data from database
type Queryer interface {
QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row
QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
}

View File

@ -169,9 +169,9 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) {
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
func (db *Base) DropTableSQL(tableName, autoincrCol string) (string, bool) {
quote := db.dialect.Quoter().Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
return []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))}, true
}
// HasRecords returns true if the SQL has records returned

View File

@ -421,10 +421,10 @@ func (db *mssql) AutoIncrStr() string {
return "IDENTITY"
}
func (db *mssql) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
func (db *mssql) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
return []string{fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
"object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+
"DROP TABLE \"%s\"", tableName, tableName), true
"DROP TABLE \"%s\"", tableName, tableName)}, true
}
func (db *mssql) ModifyColumnSQL(tableName string, col *schemas.Column) string {

View File

@ -553,9 +553,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.TimeStamp
res = schemas.Date
case schemas.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE"
res = "TIMESTAMP"
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
res = "NUMBER"
case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
@ -601,8 +601,14 @@ func (db *oracle) IsReserved(name string) bool {
return ok
}
func (db *oracle) DropTableSQL(tableName string) (string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false
func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
var sqls = []string{
fmt.Sprintf("DROP TABLE `%s`", tableName),
}
if autoincrCol != "" {
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName)))
}
return sqls, false
}
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
@ -672,12 +678,32 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
return db.HasRecords(queryer, ctx, query, args...)
}
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
args := []interface{}{tableName}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
func seqName(tableName string) string {
return "SEQ_" + strings.ToUpper(tableName)
}
rows, err := queryer.QueryContext(ctx, s, args...)
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
//s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
// "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
s := `select column_name from user_cons_columns
where constraint_name = (select constraint_name from user_constraints
where table_name = :1 and constraint_type ='P')`
var pkName string
err := queryer.QueryRowContext(ctx, s, tableName).Scan(&pkName)
if err != nil {
return nil, nil, err
}
s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH,
USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE,
user_col_comments.comments
FROM USER_TAB_COLS
LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME
AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME
WHERE USER_TAB_COLS.table_name = :1`
rows, err := queryer.QueryContext(ctx, s, tableName)
if err != nil {
return nil, nil, err
}
@ -689,11 +715,11 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col := new(schemas.Column)
col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
var colName, colDefault, nullable, dataType, dataPrecision, dataScale, comment *string
var dataLen int
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
&dataScale, &nullable)
&dataScale, &nullable, &comment)
if err != nil {
return nil, nil, err
}
@ -710,10 +736,28 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Nullable = false
}
var ignore bool
if comment != nil {
col.Comment = *comment
}
if pkName != "" && pkName == col.Name {
col.IsPrimaryKey = true
var dt string
var len1, len2 int
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName))
if err != nil {
return nil, nil, err
}
if has {
col.IsAutoIncrement = true
}
fmt.Println("-----", pkName, col.Name, col.IsPrimaryKey)
}
var (
ignore bool
dt string
len1, len2 int
)
dts := strings.Split(*dataType, "(")
dt = dts[0]
if len(dts) > 1 {
@ -769,6 +813,10 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
return nil, nil, rows.Err()
}
/*select *
from user_tab_comments
where Table_Name='用户表' */
return colSeq, cols, nil
}

View File

@ -63,8 +63,7 @@ func TestAutoTransaction(t *testing.T) {
engine.Transaction(func(session *xorm.Session) (interface{}, error) {
_, err := session.Insert(TestTx{Msg: "hi"})
assert.NoError(t, err)
return nil, nil
return nil, err
})
has, err := engine.Exist(&TestTx{Msg: "hi"})

View File

@ -162,7 +162,7 @@ func createEngine(dbType, connStr string) error {
if err != nil {
return err
}
var tableNames = make([]interface{}, 0, len(tables))
var tableNames []interface{}
for _, table := range tables {
tableNames = append(tableNames, table.Name)
}

View File

@ -257,7 +257,9 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
var tableName = session.statement.TableName()
if tableName == "" {
return 0, ErrTableNotFound
}
@ -271,7 +273,6 @@ func (session *Session) insertStruct(bean interface{}) (int64, error) {
processor.BeforeInsert()
}
var tableName = session.statement.TableName()
table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean)

View File

@ -148,8 +148,33 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
}
func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName)
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
var tableName, autoIncrementCol string
switch beanOrTableName.(type) {
case *schemas.Table:
table := beanOrTableName.(*schemas.Table)
tableName = table.Name
if table.AutoIncrColumn() != nil {
autoIncrementCol = table.AutoIncrColumn().Name
}
case string:
tableName = beanOrTableName.(string)
default:
v := utils.ReflectValue(beanOrTableName)
table, err := session.engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
if session.statement.AltTableName != "" {
tableName = session.statement.AltTableName
} else {
tableName = table.Name
}
if table.AutoIncrColumn() != nil {
autoIncrementCol = table.AutoIncrColumn().Name
}
}
sqlStrs, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol)
if !checkIfExist {
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
if err != nil {