diff --git a/engine.go b/engine.go index 7c2c2f4f..f92302e9 100644 --- a/engine.go +++ b/engine.go @@ -124,29 +124,49 @@ func (engine *Engine) QuoteStr() string { } // Quote Use QuoteStr quote the string sql -func (engine *Engine) Quote(sql string) string { - return engine.quoteTable(sql) +func (engine *Engine) Quote(value string) string { + value = strings.TrimSpace(value) + if len(value) == 0 { + return value + } + + if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { + return value + } + + value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) + + return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr() +} + +// QuoteTo quotes string and writes into the buffer +func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { + + if buf == nil { + return + } + + value = strings.TrimSpace(value) + if value == "" { + return + } + + if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { + buf.WriteString(value) + return + } + + value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) + + buf.WriteString(engine.dialect.QuoteStr()) + buf.WriteString(value) + buf.WriteString(engine.dialect.QuoteStr()) } func (engine *Engine) quote(sql string) string { return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() } -func (engine *Engine) quoteTable(keyName string) string { - keyName = strings.TrimSpace(keyName) - if len(keyName) == 0 { - return keyName - } - - if string(keyName[0]) == engine.dialect.QuoteStr() || keyName[0] == '`' { - return keyName - } - - keyName = strings.Replace(keyName, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) - - return engine.dialect.QuoteStr() + keyName + engine.dialect.QuoteStr() -} - // SqlType will be depracated, please use SQLType instead func (engine *Engine) SqlType(c *core.Column) string { return engine.dialect.SqlType(c) diff --git a/statement.go b/statement.go index 82b86674..dc426946 100644 --- a/statement.go +++ b/statement.go @@ -985,41 +985,45 @@ func (statement *Statement) Unscoped() *Statement { } func (statement *Statement) genColumnStr() string { - table := statement.RefTable - var colNames []string - for _, col := range table.Columns() { + + var buf bytes.Buffer + + columns := statement.RefTable.Columns() + + for _, col := range columns { + if statement.OmitStr != "" { if _, ok := statement.columnMap[strings.ToLower(col.Name)]; ok { continue } } + if col.MapType == core.ONLYTODB { continue } - if statement.JoinStr != "" { - var name string - if statement.TableAlias != "" { - name = statement.Engine.Quote(statement.TableAlias) - } else { - name = statement.Engine.Quote(statement.TableName()) - } - name += "." + statement.Engine.Quote(col.Name) - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - colNames = append(colNames, "id() AS "+name) - } else { - colNames = append(colNames, name) - } - } else { - name := statement.Engine.Quote(col.Name) - if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { - colNames = append(colNames, "id() AS "+name) - } else { - colNames = append(colNames, name) - } + if buf.Len() != 0 { + buf.WriteString(", ") } + + if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { + buf.WriteString("id() AS ") + } + + if statement.JoinStr != "" { + if statement.TableAlias != "" { + buf.WriteString(statement.TableAlias) + } else { + buf.WriteString(statement.TableName()) + } + + buf.WriteString(".") + } + + statement.Engine.QuoteTo(&buf, col.Name) } - return strings.Join(colNames, ", ") + + return buf.String() } func (statement *Statement) genCreateTableSQL() string { diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 00000000..3929a76e --- /dev/null +++ b/statement_test.go @@ -0,0 +1,139 @@ +package xorm + +import ( + "reflect" + "sync" + "testing" + "time" + + "github.com/go-xorm/core" +) + +var colStrTests = []struct { + omitColumn string + onlyToDBColumnNdx int + expected string +}{ + {"", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, + {"Code2", -1, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, + {"", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`, `Longitude`"}, + {"Code3", 1, "`ID`, `Caption`, `Code1`, `Code2`, `ParentID`, `Latitude`, `Longitude`"}, + {"Longitude", 1, "`ID`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, + {"", 8, "`ID`, `IsDeleted`, `Caption`, `Code1`, `Code2`, `Code3`, `ParentID`, `Latitude`"}, +} + +// !nemec784! Only for Statement object creation +const driverName = "mysql" +const dataSourceName = "Server=TestServer;Database=TestDB;Uid=testUser;Pwd=testPassword;" + +func init() { + core.RegisterDriver(driverName, &mysqlDriver{}) +} + +func TestColumnsStringGeneration(t *testing.T) { + + var statement *Statement + + for ndx, testCase := range colStrTests { + + statement = createTestStatement() + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + } + + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB // !nemec784! Column must be skipped + } + + actual := statement.genColumnStr() + + if actual != testCase.expected { + t.Errorf("[test #%d] Unexpected columns string:\nwant:\t%s\nhave:\t%s", ndx, testCase.expected, actual) + } + } +} + +func BenchmarkColumnsStringGeneration(b *testing.B) { + + b.StopTimer() + + statement := createTestStatement() + + testCase := colStrTests[0] + + if testCase.omitColumn != "" { + statement.Omit(testCase.omitColumn) // !nemec784! Column must be skipped + } + + if testCase.onlyToDBColumnNdx >= 0 { + columns := statement.RefTable.Columns() + columns[testCase.onlyToDBColumnNdx].MapType = core.ONLYTODB // !nemec784! Column must be skipped + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + actual := statement.genColumnStr() + + if actual != testCase.expected { + b.Errorf("Unexpected columns string:\nwant:\t%s\nhave:\t%s", testCase.expected, actual) + } + } +} + +type TestType struct { + ID int64 `xorm:"ID PK"` + IsDeleted bool `xorm:"IsDeleted"` + Caption string `xorm:"Caption"` + Code1 string `xorm:"Code1"` + Code2 string `xorm:"Code2"` + Code3 string `xorm:"Code3"` + ParentID int64 `xorm:"ParentID"` + Latitude float64 `xorm:"Latitude"` + Longitude float64 `xorm:"Longitude"` +} + +func (TestType) TableName() string { + return "TestTable" +} + +func createTestStatement() *Statement { + + engine := createTestEngine() + + statement := &Statement{} + statement.Init() + statement.Engine = engine + statement.setRefValue(reflect.ValueOf(TestType{})) + + return statement +} + +func createTestEngine() *Engine { + driver := core.QueryDriver(driverName) + uri, err := driver.Parse(driverName, dataSourceName) + + if err != nil { + panic(err) + } + + dialect := &mysql{} + err = dialect.Init(nil, uri, driverName, dataSourceName) + + if err != nil { + panic(err) + } + + engine := &Engine{ + dialect: dialect, + Tables: make(map[reflect.Type]*core.Table), + mutex: &sync.RWMutex{}, + TagIdentifier: "xorm", + TZLocation: time.Local, + } + engine.SetMapper(core.NewCacheMapper(new(core.SnakeMapper))) + + return engine +}