diff --git a/circle.yml b/circle.yml index 3063ac9d..c5437881 100644 --- a/circle.yml +++ b/circle.yml @@ -2,7 +2,6 @@ dependencies: override: # './...' is a relative pattern which means all subdirectories - go get -t -d -v ./... - - go get -t -d -v github.com/go-xorm/tests - go get -u github.com/go-xorm/core - go get -u github.com/go-xorm/builder - go build -v @@ -21,9 +20,13 @@ database: test: override: # './...' is a relative pattern which means all subdirectories - - go test -v -race -db="sqlite3::mysql::postgres" -conn_str="./test.db::root:@/xorm_test::dbname=xorm_test sslmode=disable" -coverprofile=coverage.txt -covermode=atomic - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./sqlite3.sh - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./mysql.sh - - cd /home/ubuntu/.go_workspace/src/github.com/go-xorm/tests && ./postgres.sh + - go test -v -race -db="sqlite3" -conn_str="./test.db" -coverprofile=coverage1.txt -covermode=atomic + - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -coverprofile=coverage2.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -coverprofile=coverage3.txt -covermode=atomic + - go test -v -race -db="sqlite3" -conn_str="./test.db" -quote=2 -coverprofile=coverage4.txt -covermode=atomic + - go test -v -race -db="mysql" -conn_str="root:@/xorm_test" -quote=2 -coverprofile=coverage5.txt -covermode=atomic + - go test -v -race -db="postgres" -conn_str="dbname=xorm_test sslmode=disable" -quote=2 -coverprofile=coverage6.txt -covermode=atomic + - go get github.com/wadey/gocovmerge + - gocovmerge coverage1.txt coverage2.txt coverage3.txt coverage4.txt coverage5.txt coverage6.txt > coverage.txt post: - bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/dialect_mssql.go b/dialect_mssql.go index 6d2291dc..e1be1991 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -277,7 +277,7 @@ func (db *mssql) SupportInsertMany() bool { } func (db *mssql) IsReserved(name string) bool { - _, ok := mssqlReservedWords[name] + _, ok := mssqlReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_mysql.go b/dialect_mysql.go index 99100b23..ad41c772 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -249,7 +249,7 @@ func (db *mysql) SupportInsertMany() bool { } func (db *mysql) IsReserved(name string) bool { - _, ok := mysqlReservedWords[name] + _, ok := mysqlReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_oracle.go b/dialect_oracle.go index ac0081b3..e7f5b5f4 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -547,7 +547,7 @@ func (db *oracle) SupportInsertMany() bool { } func (db *oracle) IsReserved(name string) bool { - _, ok := oracleReservedWords[name] + _, ok := oracleReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_postgres.go b/dialect_postgres.go index 1d4daa27..8be29881 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -836,7 +836,7 @@ func (db *postgres) SupportInsertMany() bool { } func (db *postgres) IsReserved(name string) bool { - _, ok := postgresReservedWords[name] + _, ok := postgresReservedWords[strings.ToUpper(name)] return ok } diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index a55b1615..a63123b6 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -194,7 +194,7 @@ func (db *sqlite3) SupportInsertMany() bool { } func (db *sqlite3) IsReserved(name string) bool { - _, ok := sqlite3ReservedWords[name] + _, ok := sqlite3ReservedWords[strings.ToUpper(name)] return ok } diff --git a/engine.go b/engine.go index 76e59b22..674ed1ad 100644 --- a/engine.go +++ b/engine.go @@ -130,15 +130,6 @@ func (engine *Engine) SupportInsertMany() bool { return engine.dialect.SupportInsertMany() } -// QuoteStr Engine's database use which character as quote. -// mysql, sqlite use ` and postgres use " -func (engine *Engine) QuoteStr() string { - if engine.QuoteMode == QuoteNoAdd { - return "" - } - return engine.dialect.QuoteStr() -} - // Quote Use QuoteStr quote the string sql func (engine *Engine) Quote(value string) string { var buf string @@ -147,6 +138,22 @@ func (engine *Engine) Quote(value string) string { return b.String() } +func (engine *Engine) needQuote(value string) bool { + return engine.QuoteMode == QuoteAddAlways || + (engine.QuoteMode == QuoteAddReserved && engine.dialect.IsReserved(value)) +} + +func (engine *Engine) reverseQuote(value string) string { + if !engine.needQuote(value) { + return value + } + return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr() +} + +func (engine *Engine) removeQuotes(value string) string { + return strings.Replace(strings.Replace(value, "`", "", -1), engine.dialect.QuoteStr(), "", -1) +} + // QuoteTo quotes string and writes into the buffer func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { if buf == nil { @@ -159,20 +166,15 @@ func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { } v := strings.Trim(value, "`"+engine.dialect.QuoteStr()) - if engine.QuoteMode == QuoteNoAdd || - (engine.QuoteMode == QuoteAddReserved && !engine.dialect.IsReserved(v)) { + if !engine.needQuote(v) { buf.WriteString(v) return } - v = strings.Replace(v, ".", engine.QuoteStr()+"."+engine.QuoteStr(), -1) + v = strings.Replace(v, ".", engine.reverseQuote("."), -1) buf.WriteString(engine.dialect.Quote(v)) } -func (engine *Engine) quote(sql string) string { - return engine.QuoteStr() + sql + engine.QuoteStr() -} - // SqlType will be depracated, please use SQLType instead // // Deprecated: use SQLType instead diff --git a/session_insert.go b/session_insert.go index c3648171..f53eab1c 100644 --- a/session_insert.go +++ b/session_insert.go @@ -212,28 +212,16 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } cleanupProcessorsClosures(&session.beforeClosures) - var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" + var sql = "INSERT INTO %s (%v) VALUES (%v)" var statement string + var tbName = session.Engine.Quote(session.Statement.TableName()) + var quoteColNames = session.Engine.Quote(strings.Join(colNames, session.Engine.reverseQuote(", "))) if session.Engine.dialect.DBType() == core.ORACLE { - sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" - temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr()) - statement = fmt.Sprintf(sql, - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr(), - strings.Join(colMultiPlaces, temp)) + sql = "INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL" + temp := fmt.Sprintf(") INTO %s (%v) VALUES (", tbName, quoteColNames) + statement = fmt.Sprintf(sql, tbName, quoteColNames, strings.Join(colMultiPlaces, temp)) } else { - statement = fmt.Sprintf(sql, - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.QuoteStr()+", "+session.Engine.QuoteStr()), - session.Engine.QuoteStr(), - strings.Join(colMultiPlaces, "),(")) + statement = fmt.Sprintf(sql, tbName, quoteColNames, strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(statement, args...) if err != nil { @@ -349,18 +337,17 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } var sqlStr string + var tbName = session.Engine.Quote(session.Statement.TableName()) if len(colPlaces) > 0 { - sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", - session.Engine.Quote(session.Statement.TableName()), - session.Engine.QuoteStr(), - strings.Join(colNames, session.Engine.Quote(", ")), - session.Engine.QuoteStr(), + sqlStr = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", + tbName, + session.Engine.Quote(strings.Join(colNames, session.Engine.reverseQuote(", "))), colPlaces) } else { if session.Engine.dialect.DBType() == core.MYSQL { - sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.Engine.Quote(session.Statement.TableName())) + sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", tbName) } else { - sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", session.Engine.Quote(session.Statement.TableName())) + sqlStr = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES", tbName) } } diff --git a/session_update.go b/session_update.go index 792fb574..53fd7cd5 100644 --- a/session_update.go +++ b/session_update.go @@ -101,11 +101,9 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] - if strings.Contains(colName, "`") { - colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) - } else if strings.Contains(colName, session.Engine.QuoteStr()) { - colName = strings.TrimSpace(strings.Replace(colName, session.Engine.QuoteStr(), "", -1)) - } else { + colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) + colName = strings.TrimSpace(session.Engine.removeQuotes(colName)) + if colName == "" { session.Engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) return ErrCacheFailed } diff --git a/statement.go b/statement.go index 6e360bb3..c111f17c 100644 --- a/statement.go +++ b/statement.go @@ -574,16 +574,15 @@ func (statement *Statement) getExpr() map[string]exprParam { func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { newColumns := make([]string, 0) for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1) + col = statement.Engine.removeQuotes(col) ccols := strings.Split(col, ",") for _, c := range ccols { fields := strings.Split(strings.TrimSpace(c), ".") if len(fields) == 1 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])) + newColumns = append(newColumns, statement.Engine.Quote(fields[0])) } else if len(fields) == 2 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ - statement.Engine.quote(fields[1])) + newColumns = append(newColumns, statement.Engine.Quote(fields[0])+"."+ + statement.Engine.Quote(fields[1])) } else { panic(errors.New("unwanted colnames")) } @@ -620,7 +619,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { newColumns := statement.col2NewColsWithQuote(columns...) statement.ColumnStr = strings.Join(newColumns, ", ") - statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) + statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.Quote("*"), "*", -1) return statement } @@ -836,7 +835,7 @@ func (statement *Statement) genIndexSQL() []string { for idxName, index := range statement.RefTable.Indexes { if index.Type == core.IndexType { sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), - quote(tbName), quote(strings.Join(index.Cols, quote(",")))) + quote(tbName), quote(strings.Join(index.Cols, statement.Engine.reverseQuote(",")))) sqls = append(sqls, sql) } } diff --git a/statement_test.go b/statement_test.go index 32a07123..47b10044 100644 --- a/statement_test.go +++ b/statement_test.go @@ -26,7 +26,7 @@ var colStrTests = []struct { } func TestColumnsStringGeneration(t *testing.T) { - if dbType == "postgres" || dbType == "mssql" { + if dbType == "postgres" || dbType == "mssql" || *quote != int(QuoteAddAlways) { return } diff --git a/xorm_test.go b/xorm_test.go index 24062d53..998010a0 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -24,6 +24,7 @@ var ( ptrConnStr = flag.String("conn_str", "", "test database connection string") mapType = flag.String("map_type", "snake", "indicate the name mapping") cache = flag.Bool("cache", false, "if enable cache") + quote = flag.Int("quote", int(QuoteAddAlways), "quote mode") ) func createEngine(dbType, connStr string) error { @@ -36,7 +37,7 @@ func createEngine(dbType, connStr string) error { testEngine.ShowSQL(*showSQL) testEngine.logger.SetLevel(core.LOG_DEBUG) - //testEngine.QuoteMode = QuoteAddReserved + testEngine.QuoteMode = QuoteMode(*quote) } tables, err := testEngine.DBMetas()