fix bugs
This commit is contained in:
parent
8b5da97cac
commit
4e99c291d8
8
Makefile
8
Makefile
|
@ -187,25 +187,25 @@ test-oracle: test-godror
|
|||
|
||||
.PNONY: test-oci8
|
||||
test-oci8: go-check
|
||||
$(GO) test -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
$(GO) test -v -race -tags=oracle -db=oci8 -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
-conn_str="$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
|
||||
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
|
||||
|
||||
.PHONY: test-oci8\#%
|
||||
test-oci8\#%: go-check
|
||||
$(GO) test -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
$(GO) test -v -race -run $* -tags=oracle -db=oci8 -schema='$(TEST_PGSQL_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
-conn_str="postgres://$(TEST_PGSQL_USERNAME):$(TEST_PGSQL_PASSWORD)@$(TEST_PGSQL_HOST)/$(TEST_PGSQL_DBNAME)" \
|
||||
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
|
||||
|
||||
.PHONY: test-godror
|
||||
test-godror: go-check
|
||||
$(GO) test -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
$(GO) test -v -race -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
|
||||
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
|
||||
|
||||
.PHONY: test-godror\#%
|
||||
test-godror\#%: go-check
|
||||
$(GO) test -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
$(GO) test -v -race -run $* -tags=oracle -db=godror -schema='$(TEST_ORACLE_SCHEMA)' -cache=$(TEST_CACHE_ENABLE) \
|
||||
-conn_str="oracle://$(TEST_ORACLE_USERNAME):$(TEST_ORACLE_PASSWORD)@$(TEST_ORACLE_HOST)/$(TEST_ORACLE_DBNAME)" \
|
||||
-coverprofile=oracle.$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
// Queryer represents an interface to query a SQL to get data from database
|
||||
type Queryer interface {
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
|
||||
}
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ type Dialect interface {
|
|||
GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error)
|
||||
IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error)
|
||||
CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool)
|
||||
DropTableSQL(tableName string) (string, bool)
|
||||
DropTableSQL(tableName, autoincrCol string) ([]string, bool)
|
||||
|
||||
GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error)
|
||||
IsColumnExist(queryer core.Queryer, ctx context.Context, tableName string, colName string) (bool, error)
|
||||
|
@ -100,9 +100,9 @@ func (b *Base) FormatBytes(bs []byte) string {
|
|||
return fmt.Sprintf("0x%x", bs)
|
||||
}
|
||||
|
||||
func (db *Base) DropTableSQL(tableName string) (string, bool) {
|
||||
func (db *Base) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
|
||||
quote := db.dialect.Quoter().Quote
|
||||
return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)), true
|
||||
return []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName))}, true
|
||||
}
|
||||
|
||||
func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query string, args ...interface{}) (bool, error) {
|
||||
|
|
|
@ -311,10 +311,10 @@ func (db *mssql) AutoIncrStr() string {
|
|||
return "IDENTITY"
|
||||
}
|
||||
|
||||
func (db *mssql) DropTableSQL(tableName string) (string, bool) {
|
||||
return fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
|
||||
func (db *mssql) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
|
||||
return []string{fmt.Sprintf("IF EXISTS (SELECT * FROM sysobjects WHERE id = "+
|
||||
"object_id(N'%s') and OBJECTPROPERTY(id, N'IsUserTable') = 1) "+
|
||||
"DROP TABLE \"%s\"", tableName, tableName), true
|
||||
"DROP TABLE \"%s\"", tableName, tableName)}, true
|
||||
}
|
||||
|
||||
func (db *mssql) IndexCheckSQL(tableName, idxName string) (string, []interface{}) {
|
||||
|
|
|
@ -523,9 +523,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
|
|||
case schemas.Binary, schemas.VarBinary, schemas.Blob, schemas.TinyBlob, schemas.MediumBlob, schemas.LongBlob, schemas.Bytea:
|
||||
return schemas.Blob
|
||||
case schemas.Time, schemas.DateTime, schemas.TimeStamp:
|
||||
res = schemas.TimeStamp
|
||||
res = schemas.Date
|
||||
case schemas.TimeStampz:
|
||||
res = "TIMESTAMP WITH TIME ZONE"
|
||||
res = "TIMESTAMP"
|
||||
case schemas.Float, schemas.Double, schemas.Numeric, schemas.Decimal:
|
||||
res = "NUMBER"
|
||||
case schemas.Text, schemas.MediumText, schemas.LongText, schemas.Json:
|
||||
|
@ -556,8 +556,14 @@ func (db *oracle) IsReserved(name string) bool {
|
|||
return ok
|
||||
}
|
||||
|
||||
func (db *oracle) DropTableSQL(tableName string) (string, bool) {
|
||||
return fmt.Sprintf("DROP TABLE `%s`", tableName), false
|
||||
func (db *oracle) DropTableSQL(tableName, autoincrCol string) ([]string, bool) {
|
||||
var sqls = []string{
|
||||
fmt.Sprintf("DROP TABLE `%s`", tableName),
|
||||
}
|
||||
if autoincrCol != "" {
|
||||
sqls = append(sqls, fmt.Sprintf("DROP SEQUENCE %s", seqName(tableName)))
|
||||
}
|
||||
return sqls, false
|
||||
}
|
||||
|
||||
func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]string, bool) {
|
||||
|
@ -589,8 +595,19 @@ func (db *oracle) CreateTableSQL(table *schemas.Table, tableName string) ([]stri
|
|||
sql += " ), "
|
||||
}
|
||||
|
||||
sql = sql[:len(sql)-2] + ")"
|
||||
return []string{sql}, false
|
||||
var sqls = []string{sql[:len(sql)-2] + ")"}
|
||||
|
||||
if table.AutoIncrColumn() != nil {
|
||||
var sql2 = fmt.Sprintf(`CREATE sequence %s
|
||||
minvalue 1
|
||||
nomaxvalue
|
||||
start with 1
|
||||
increment by 1
|
||||
nocycle
|
||||
nocache`, seqName(tableName))
|
||||
sqls = append(sqls, sql2)
|
||||
}
|
||||
return sqls, false
|
||||
}
|
||||
|
||||
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
|
||||
|
@ -627,12 +644,32 @@ func (db *oracle) IsColumnExist(queryer core.Queryer, ctx context.Context, table
|
|||
return db.HasRecords(queryer, ctx, query, args...)
|
||||
}
|
||||
|
||||
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []interface{}{tableName}
|
||||
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
|
||||
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
|
||||
func seqName(tableName string) string {
|
||||
return "SEQ_" + strings.ToUpper(tableName)
|
||||
}
|
||||
|
||||
rows, err := queryer.QueryContext(ctx, s, args...)
|
||||
func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
//s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
|
||||
// "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
|
||||
|
||||
s := `select column_name from user_cons_columns
|
||||
where constraint_name = (select constraint_name from user_constraints
|
||||
where table_name = :1 and constraint_type ='P')`
|
||||
var pkName string
|
||||
err := queryer.QueryRowContext(ctx, s, tableName).Scan(&pkName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
s = `SELECT USER_TAB_COLS.COLUMN_NAME, USER_TAB_COLS.DATA_DEFAULT, USER_TAB_COLS.DATA_TYPE, USER_TAB_COLS.DATA_LENGTH,
|
||||
USER_TAB_COLS.data_precision, USER_TAB_COLS.data_scale, USER_TAB_COLS.NULLABLE,
|
||||
user_col_comments.comments
|
||||
FROM USER_TAB_COLS
|
||||
LEFT JOIN user_col_comments on user_col_comments.TABLE_NAME=USER_TAB_COLS.TABLE_NAME
|
||||
AND user_col_comments.COLUMN_NAME=USER_TAB_COLS.COLUMN_NAME
|
||||
WHERE USER_TAB_COLS.table_name = :1`
|
||||
|
||||
rows, err := queryer.QueryContext(ctx, s, tableName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -644,11 +681,11 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
col := new(schemas.Column)
|
||||
col.Indexes = make(map[string]int)
|
||||
|
||||
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
|
||||
var colName, colDefault, nullable, dataType, dataPrecision, dataScale, comment *string
|
||||
var dataLen int
|
||||
|
||||
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
|
||||
&dataScale, &nullable)
|
||||
&dataScale, &nullable, &comment)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -665,10 +702,28 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
col.Nullable = false
|
||||
}
|
||||
|
||||
var ignore bool
|
||||
if comment != nil {
|
||||
col.Comment = *comment
|
||||
}
|
||||
if pkName != "" && pkName == col.Name {
|
||||
col.IsPrimaryKey = true
|
||||
|
||||
var dt string
|
||||
var len1, len2 int
|
||||
has, err := db.HasRecords(queryer, ctx, "SELECT * FROM USER_SEQUENCES WHERE SEQUENCE_NAME = :1", seqName(tableName))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if has {
|
||||
col.IsAutoIncrement = true
|
||||
}
|
||||
|
||||
fmt.Println("-----", pkName, col.Name, col.IsPrimaryKey)
|
||||
}
|
||||
|
||||
var (
|
||||
ignore bool
|
||||
dt string
|
||||
len1, len2 int
|
||||
)
|
||||
dts := strings.Split(*dataType, "(")
|
||||
dt = dts[0]
|
||||
if len(dts) > 1 {
|
||||
|
@ -721,6 +776,10 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
|
|||
colSeq = append(colSeq, col.Name)
|
||||
}
|
||||
|
||||
/*select *
|
||||
from user_tab_comments
|
||||
where Table_Name='用户表' */
|
||||
|
||||
return colSeq, cols, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -60,8 +60,7 @@ func TestAutoTransaction(t *testing.T) {
|
|||
engine.Transaction(func(session *xorm.Session) (interface{}, error) {
|
||||
_, err := session.Insert(TestTx{Msg: "hi"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
return nil, nil
|
||||
return nil, err
|
||||
})
|
||||
|
||||
has, err := engine.Exist(&TestTx{Msg: "hi"})
|
||||
|
|
|
@ -146,13 +146,12 @@ func createEngine(dbType, connStr string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var tableNames = make([]interface{}, 0, len(tables))
|
||||
for _, table := range tables {
|
||||
tableNames = append(tableNames, table.Name)
|
||||
}
|
||||
if err = testEngine.DropTables(tableNames...); err != nil {
|
||||
return err
|
||||
if err = testEngine.DropTables(table); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -59,6 +59,10 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE {
|
||||
colNames = append(colNames, table.AutoIncrement)
|
||||
}
|
||||
|
||||
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(colNames, exprs.ColNames...), ","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
@ -112,6 +116,16 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []interface{})
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
// Insert tablename (id) Values(seq_tablename.nextval)
|
||||
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.ORACLE {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if _, err := buf.WriteString("seq_" + tableName + ".nextval"); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(exprs.Args) > 0 {
|
||||
if _, err := buf.WriteString(","); err != nil {
|
||||
return "", nil, err
|
||||
|
|
|
@ -279,7 +279,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
if err := session.statement.SetRefBean(bean); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(session.statement.TableName()) <= 0 {
|
||||
|
||||
var tableName = session.statement.TableName()
|
||||
if tableName == "" {
|
||||
return 0, ErrTableNotFound
|
||||
}
|
||||
|
||||
|
@ -293,7 +295,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
processor.BeforeInsert()
|
||||
}
|
||||
|
||||
var tableName = session.statement.TableName()
|
||||
table := session.statement.RefTable
|
||||
|
||||
colNames, args, err := session.genInsertColumns(bean)
|
||||
|
@ -355,19 +356,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
}
|
||||
}
|
||||
|
||||
res, err := session.queryBytes("select seq_atable.currval from dual")
|
||||
var id int64
|
||||
err = session.queryRow(fmt.Sprintf("select %s.currval from dual", tableName)).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(res) < 1 {
|
||||
return 0, errors.New("insert no error but not returned id")
|
||||
}
|
||||
|
||||
idByte := res[0][table.AutoIncrement]
|
||||
id, err := strconv.ParseInt(string(idByte), 10, 64)
|
||||
if err != nil || id <= 0 {
|
||||
return 1, err
|
||||
}
|
||||
if id == 0 {
|
||||
return 1, errors.New("insert no error but not returned id")
|
||||
}
|
||||
|
||||
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
|
||||
if err != nil {
|
||||
|
|
|
@ -131,8 +131,33 @@ func (session *Session) DropTable(beanOrTableName interface{}) error {
|
|||
}
|
||||
|
||||
func (session *Session) dropTable(beanOrTableName interface{}) error {
|
||||
tableName := session.engine.TableName(beanOrTableName)
|
||||
sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true))
|
||||
var tableName, autoIncrementCol string
|
||||
switch beanOrTableName.(type) {
|
||||
case *schemas.Table:
|
||||
table := beanOrTableName.(*schemas.Table)
|
||||
tableName = table.Name
|
||||
if table.AutoIncrColumn() != nil {
|
||||
autoIncrementCol = table.AutoIncrColumn().Name
|
||||
}
|
||||
case string:
|
||||
tableName = beanOrTableName.(string)
|
||||
default:
|
||||
v := utils.ReflectValue(beanOrTableName)
|
||||
table, err := session.engine.tagParser.ParseWithCache(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if session.statement.AltTableName != "" {
|
||||
tableName = session.statement.AltTableName
|
||||
} else {
|
||||
tableName = table.Name
|
||||
}
|
||||
if table.AutoIncrColumn() != nil {
|
||||
autoIncrementCol = table.AutoIncrColumn().Name
|
||||
}
|
||||
}
|
||||
|
||||
sqlStrs, checkIfExist := session.engine.dialect.DropTableSQL(tableName, autoIncrementCol)
|
||||
if !checkIfExist {
|
||||
exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName)
|
||||
if err != nil {
|
||||
|
@ -142,8 +167,11 @@ func (session *Session) dropTable(beanOrTableName interface{}) error {
|
|||
}
|
||||
|
||||
if checkIfExist {
|
||||
_, err := session.exec(sqlStr)
|
||||
return err
|
||||
for _, sqlStr := range sqlStrs {
|
||||
if _, err := session.exec(sqlStr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue