This commit is contained in:
Lunny Xiao 2020-03-08 01:06:27 +08:00
parent 8b5da97cac
commit 4e99c291d8
No known key found for this signature in database
GPG Key ID: C3B7C91B632F738A
10 changed files with 145 additions and 49 deletions

View File

@ -187,25 +187,25 @@ test-oracle: test-godror
.PNONY: test-oci8 .PNONY: test-oci8
test-oci8: go-check test-oci8: go-check
$(GO) test -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ $(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-oci8\#% .PHONY: test-oci8\#%
test-oci8\#%: go-check test-oci8\#%: go-check
$(GO) test -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ $(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \ -conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-godror .PHONY: test-godror
test-godror: go-check test-godror: go-check
$(GO) test -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ $(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
.PHONY: test-godror\#% .PHONY: test-godror\#%
test-godror\#%: go-check test-godror\#%: go-check
$(GO) test -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \ $(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \ -conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic

View File

@ -7,6 +7,7 @@ import (
// Queryer represents an interface to query a SQL to get data from database // Queryer represents an interface to query a SQL to get data from database
type Queryer interface { type Queryer interface {
QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row
QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
} }

View File

@ -59,7 +59,7 @@ type Dialect interface {
GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
DropTableSQL(tableName string) (string, bool) DropTableSQL(tableName, autoincrCol string) ([]string, bool)
GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
@ -100,9 +100,9 @@ func (b *Base) FormatBytes(bs []byte) string {
return fmt.Sprintf("0x%x", bs) return fmt.Sprintf("0x%x", bs)
} }
func (db *Base) DropTableSQL(tableName string) (string, bool) { func (db *Base) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
quote := db.dialect.Quoter().Quote quote := db.dialect.Quoter().Quote
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true return []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))}, true
} }
func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) { func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {

View File

@ -311,10 +311,10 @@ func (db *mssql) AutoIncrStr() string {
return "IDENTITY" return "IDENTITY"
} }
func (db *mssql) DropTableSQL(tableName string) (string, bool) { func (db *mssql) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+ return []string{fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
"object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+ "object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+
"DROP TABLE \"%s\"", tableName, tableName), true "DROP TABLE \"%s\"", tableName, tableName)}, true
} }
func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {

View File

@ -523,9 +523,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea: case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
return schemas.Blob return schemas.Blob
case schemas.Time, schemas.DateTime, schemas.TimeStamp: case schemas.Time, schemas.DateTime, schemas.TimeStamp:
res = schemas.TimeStamp res = schemas.Date
case schemas.TimeStampz: case schemas.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE" res = "TIMESTAMP"
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal: case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
res = "NUMBER" res = "NUMBER"
case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json: case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
@ -556,8 +556,14 @@ func (db *oracle) IsReserved(name string) bool {
return ok return ok
} }
func (db *oracle) DropTableSQL(tableName string) (string, bool) { func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
return fmt.Sprintf("DROP TABLE `%s`", tableName), false var sqls = []string{
fmt.Sprintf("DROP TABLE `%s`", tableName),
}
if autoincrCol != "" {
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName)))
}
return sqls, false
} }
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) { func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
@ -589,8 +595,19 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri
sql += " ), " sql += " ), "
} }
sql = sql[:len(sql)-2] + ")" var sqls = []string{sql[:len(sql)-2] + ")"}
return []string{sql}, false
if table.AutoIncrColumn() != nil {
var sql2 = fmt.Sprintf(`CREATE sequence %s
minvalue 1
nomaxvalue
start with 1
increment by 1
nocycle
nocache`, seqName(tableName))
sqls = append(sqls, sql2)
}
return sqls, false
} }
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
@ -627,12 +644,32 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
return db.HasRecords(queryer, ctx, query, args...) return db.HasRecords(queryer, ctx, query, args...)
} }
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { func seqName(tableName string) string {
args := []interface{}{tableName} return "SEQ_" + strings.ToUpper(tableName)
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + }
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
rows, err := queryer.QueryContext(ctx, s, args...) func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
//s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
// "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
s := `select column_name from user_cons_columns
where constraint_name = (select constraint_name from user_constraints
where table_name = :1 and constraint_type ='P')`
var pkName string
err := queryer.QueryRowContext(ctx, s, tableName).Scan(&pkName)
if err != nil {
return nil, nil, err
}
s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH,
USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE,
user_col_comments.comments
FROM USER_TAB_COLS
LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME
AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME
WHERE USER_TAB_COLS.table_name = :1`
rows, err := queryer.QueryContext(ctx, s, tableName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -644,11 +681,11 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col := new(schemas.Column) col := new(schemas.Column)
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string var colName, colDefault, nullable, dataType, dataPrecision, dataScale, comment *string
var dataLen int var dataLen int
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
&dataScale, &nullable) &dataScale, &nullable, &comment)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -665,10 +702,28 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Nullable = false col.Nullable = false
} }
var ignore bool if comment != nil {
col.Comment = *comment
}
if pkName != "" && pkName == col.Name {
col.IsPrimaryKey = true
var dt string has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName))
var len1, len2 int if err != nil {
return nil, nil, err
}
if has {
col.IsAutoIncrement = true
}
fmt.Println("-----", pkName, col.Name, col.IsPrimaryKey)
}
var (
ignore bool
dt string
len1, len2 int
)
dts := strings.Split(*dataType, "(") dts := strings.Split(*dataType, "(")
dt = dts[0] dt = dts[0]
if len(dts) > 1 { if len(dts) > 1 {
@ -721,6 +776,10 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
/*select *
from user_tab_comments
where Table_Name='用户表' */
return colSeq, cols, nil return colSeq, cols, nil
} }

View File

@ -60,8 +60,7 @@ func TestAutoTransaction(t *testing.T) {
engine.Transaction(func(session *xorm.Session) (interface{}, error) { engine.Transaction(func(session *xorm.Session) (interface{}, error) {
_, err := session.Insert(TestTx{Msg: "hi"}) _, err := session.Insert(TestTx{Msg: "hi"})
assert.NoError(t, err) assert.NoError(t, err)
return nil, err
return nil, nil
}) })
has, err := engine.Exist(&TestTx{Msg: "hi"}) has, err := engine.Exist(&TestTx{Msg: "hi"})

View File

@ -146,13 +146,12 @@ func createEngine(dbType, connStr string) error {
if err != nil { if err != nil {
return err return err
} }
var tableNames = make([]interface{}, 0, len(tables))
for _, table := range tables { for _, table := range tables {
tableNames = append(tableNames, table.Name) if err = testEngine.DropTables(table); err != nil {
} return err
if err = testEngine.DropTables(tableNames...); err != nil { }
return err
} }
return nil return nil
} }

View File

@ -59,6 +59,10 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err return "", nil, err
} }
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE {
colNames = append(colNames, table.AutoIncrement)
}
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil { if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil {
return "", nil, err return "", nil, err
} }
@ -112,6 +116,16 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
return "", nil, err return "", nil, err
} }
// Insert tablename (id) Values(seq_tablename.nextval)
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := buf.WriteString(","); err != nil {
return "", nil, err
}
if _, err := buf.WriteString("seq_" + tableName + ".nextval"); err != nil {
return "", nil, err
}
}
if len(exprs.Args) > 0 { if len(exprs.Args) > 0 {
if _, err := buf.WriteString(","); err != nil { if _, err := buf.WriteString(","); err != nil {
return "", nil, err return "", nil, err

View File

@ -279,7 +279,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
if err := session.statement.SetRefBean(bean); err != nil { if err := session.statement.SetRefBean(bean); err != nil {
return 0, err return 0, err
} }
if len(session.statement.TableName()) <= 0 {
var tableName = session.statement.TableName()
if tableName == "" {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
@ -293,7 +295,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
processor.BeforeInsert() processor.BeforeInsert()
} }
var tableName = session.statement.TableName()
table := session.statement.RefTable table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean) colNames, args, err := session.genInsertColumns(bean)
@ -355,19 +356,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
} }
res, err := session.queryBytes("select seq_atable.currval from dual") var id int64
err = session.queryRow(fmt.Sprintf("select %s.currval from dual", tableName)).Scan(&id)
if err != nil { if err != nil {
return 0, err
}
if len(res) < 1 {
return 0, errors.New("insert no error but not returned id")
}
idByte := res[0][table.AutoIncrement]
id, err := strconv.ParseInt(string(idByte), 10, 64)
if err != nil || id <= 0 {
return 1, err return 1, err
} }
if id == 0 {
return 1, errors.New("insert no error but not returned id")
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil { if err != nil {

View File

@ -131,8 +131,33 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
} }
func (session *Session) dropTable(beanOrTableName interface{}) error { func (session *Session) dropTable(beanOrTableName interface{}) error {
tableName := session.engine.TableName(beanOrTableName) var tableName, autoIncrementCol string
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) switch beanOrTableName.(type) {
case *schemas.Table:
table := beanOrTableName.(*schemas.Table)
tableName = table.Name
if table.AutoIncrColumn() != nil {
autoIncrementCol = table.AutoIncrColumn().Name
}
case string:
tableName = beanOrTableName.(string)
default:
v := utils.ReflectValue(beanOrTableName)
table, err := session.engine.tagParser.ParseWithCache(v)
if err != nil {
return err
}
if session.statement.AltTableName != "" {
tableName = session.statement.AltTableName
} else {
tableName = table.Name
}
if table.AutoIncrColumn() != nil {
autoIncrementCol = table.AutoIncrColumn().Name
}
}
sqlStrs, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol)
if !checkIfExist { if !checkIfExist {
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
if err != nil { if err != nil {
@ -142,8 +167,11 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
} }
if checkIfExist { if checkIfExist {
_, err := session.exec(sqlStr) for _, sqlStr := range sqlStrs {
return err if _, err := session.exec(sqlStr); err != nil {
return err
}
}
} }
return nil return nil
} }