diff --git a/Makefile b/Makefile index 55183557..c895c975 100644 --- a/Makefile +++ b/Makefile @@ -163,18 +163,6 @@ test-mssql\#%: go-check -do_nvarchar_override_test=$(TEST_MSSQL_DO_NVARCHAR_OVERRIDE_TEST) \ -coverprofile=mssql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -.PNONY: test-mymysql -test-mymysql: go-check - $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ - -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ - -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic -timeout=20m - -.PNONY: test-mymysql\#% -test-mymysql\#%: go-check - $(GO) test $(INTEGRATION_PACKAGES) -v -race -run $* -db=mymysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ - -conn_str="tcp:$(TEST_MYSQL_HOST)*$(TEST_MYSQL_DBNAME)/$(TEST_MYSQL_USERNAME)/$(TEST_MYSQL_PASSWORD)" \ - -coverprofile=mymysql.$(TEST_QUOTE_POLICY).$(TEST_CACHE_ENABLE).coverage.out -covermode=atomic - .PNONY: test-mysql test-mysql: go-check $(GO) test $(INTEGRATION_PACKAGES) -v -race -db=mysql -cache=$(TEST_CACHE_ENABLE) -quote=$(TEST_QUOTE_POLICY) \ diff --git a/README.md b/README.md index e50e569f..b29d1f13 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,6 @@ Drivers for Go's sql package which currently support database/sql includes: * [Mysql5.*](https://github.com/mysql/mysql-server/tree/5.7) / [Mysql8.*](https://github.com/mysql/mysql-server) / [Mariadb](https://github.com/MariaDB/server) / [Tidb](https://github.com/pingcap/tidb) - [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) - - [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * [Postgres](https://github.com/postgres/postgres) / [Cockroach](https://github.com/cockroachdb/cockroach) - [github.com/lib/pq](https://github.com/lib/pq) diff --git a/README_CN.md b/README_CN.md index 40b3a24e..7978141f 100644 --- a/README_CN.md +++ b/README_CN.md @@ -36,7 +36,6 @@ v1.0.0 相对于 v0.8.2 有以下不兼容的变更: * [Mysql5.*](https://github.com/mysql/mysql-server/tree/5.7) / [Mysql8.*](https://github.com/mysql/mysql-server) / [Mariadb](https://github.com/MariaDB/server) / [Tidb](https://github.com/pingcap/tidb) - [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) - - [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * [Postgres](https://github.com/postgres/postgres) / [Cockroach](https://github.com/cockroachdb/cockroach) - [github.com/lib/pq](https://github.com/lib/pq) diff --git a/dialects/dialect.go b/dialects/dialect.go index c34c4c96..a5a3961e 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -274,7 +274,6 @@ func regDrvsNDialects() bool { "mssql": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, "odbc": {"mssql", func() Driver { return &odbcDriver{} }, func() Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access "mysql": {"mysql", func() Driver { return &mysqlDriver{} }, func() Dialect { return &mysql{} }}, - "mymysql": {"mysql", func() Driver { return &mymysqlDriver{} }, func() Dialect { return &mysql{} }}, "postgres": {"postgres", func() Driver { return &pqDriver{} }, func() Dialect { return &postgres{} }}, "pgx": {"postgres", func() Driver { return &pqDriverPgx{} }, func() Dialect { return &postgres{} }}, "sqlite3": {"sqlite3", func() Driver { return &sqlite3Driver{} }, func() Dialect { return &sqlite3{} }}, diff --git a/dialects/mysql.go b/dialects/mysql.go index 878c34ce..fe794ffb 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -12,7 +12,6 @@ import ( "regexp" "strconv" "strings" - "time" "xorm.io/xorm/v2/core" "xorm.io/xorm/v2/schemas" @@ -792,56 +791,3 @@ func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } - -type mymysqlDriver struct { - mysqlDriver -} - -func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { - uri := &URI{DBType: schemas.MYSQL} - - pd := strings.SplitN(dataSourceName, "*", 2) - if len(pd) == 2 { - // Parse protocol part of URI - p := strings.SplitN(pd[0], ":", 2) - if len(p) != 2 { - return nil, errors.New("wrong protocol part of URI") - } - uri.Proto = p[0] - options := strings.Split(p[1], ",") - uri.Raddr = options[0] - for _, o := range options[1:] { - kv := strings.SplitN(o, "=", 2) - var k, v string - if len(kv) == 2 { - k, v = kv[0], kv[1] - } else { - k, v = o, "true" - } - switch k { - case "laddr": - uri.Laddr = v - case "timeout": - to, err := time.ParseDuration(v) - if err != nil { - return nil, err - } - uri.Timeout = to - default: - return nil, errors.New("unknown option: " + k) - } - } - // Remove protocol part - pd = pd[1:] - } - // Parse database part of URI - dup := strings.SplitN(pd[0], "/", 3) - if len(dup) != 3 { - return nil, errors.New("Wrong database part of URI") - } - uri.DBName = dup[0] - uri.User = dup[1] - uri.Passwd = dup[2] - - return uri, nil -} diff --git a/go.mod b/go.mod index 24ca9a6f..68e4f1e5 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.1 github.com/syndtr/goleveldb v1.0.0 - github.com/ziutek/mymysql v1.5.4 modernc.org/sqlite v1.20.4 xorm.io/builder v0.3.11-0.20220531020008-1bd24a7dc978 xorm.io/xorm v1.3.4 diff --git a/go.sum b/go.sum index f4f0eee4..065e63b9 100644 --- a/go.sum +++ b/go.sum @@ -179,8 +179,6 @@ github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpP github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -github.com/ziutek/mymysql v1.5.4 h1:GB0qdRGsTwQSBVYuVShFBKaXSnSnYYC2d9knnE1LHFs= -github.com/ziutek/mymysql v1.5.4/go.mod h1:LMSpPZ6DbqWFxNCHW77HeMg9I646SAhApZ/wKdgO/C0= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= diff --git a/internal/statements/statement_args.go b/internal/statements/args.go similarity index 100% rename from internal/statements/statement_args.go rename to internal/statements/args.go diff --git a/internal/statements/pagination.go b/internal/statements/pagination.go index 5f836c3b..33c001c6 100644 --- a/internal/statements/pagination.go +++ b/internal/statements/pagination.go @@ -10,11 +10,12 @@ import ( "xorm.io/builder" "xorm.io/xorm/v2/internal/utils" + "xorm.io/xorm/v2/schemas" ) func (statement *Statement) writePagination(bw *builder.BytesWriter) error { dbType := statement.dialect.URI().DBType - if dbType == "mssql" || dbType == "oracle" { + if dbType == schemas.MSSQL || dbType == schemas.ORACLE { return statement.writeOffsetFetch(bw) } return statement.writeLimitOffset(bw) @@ -50,15 +51,15 @@ func (statement *Statement) writeOffsetFetch(w builder.Writer) error { } func (statement *Statement) writeWhereWithMssqlPagination(w *builder.BytesWriter) error { - if !statement.cond.IsValid() { - return statement.writeMssqlPaginationCond(w) - } - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } - if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { - return err + if statement.cond.IsValid() { + if _, err := fmt.Fprint(w, " WHERE "); err != nil { + return err + } + if err := statement.cond.WriteTo(statement.QuoteReplacer(w)); err != nil { + return err + } } + return statement.writeMssqlPaginationCond(w) } @@ -115,15 +116,8 @@ func (statement *Statement) writeMssqlPaginationCond(w *builder.BytesWriter) err if _, err := fmt.Fprint(subWriter, "))"); err != nil { return err } - - if statement.cond.IsValid() { - if _, err := fmt.Fprint(w, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(w, " WHERE "); err != nil { - return err - } + if err := statement.writeWhereOrAnd(w, statement.cond.IsValid()); err != nil { + return err } return utils.WriteBuilder(w, subWriter) diff --git a/internal/statements/update.go b/internal/statements/update.go index 0287a368..ef9a495d 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -350,6 +350,15 @@ func (statement *Statement) writeUpdateFrom(updateWriter *builder.BytesWriter) e return err } +func (statement *Statement) writeWhereOrAnd(updateWriter *builder.BytesWriter, hasConditions bool) error { + if hasConditions { + _, err := fmt.Fprint(updateWriter, " AND ") + return err + } + _, err := fmt.Fprint(updateWriter, " WHERE ") + return err +} + func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, cond builder.Cond) error { if statement.LimitN == nil { return nil @@ -364,14 +373,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, _, err := fmt.Fprintf(updateWriter, " LIMIT %d", limitValue) return err case schemas.SQLITE: - if cond.IsValid() { - if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { - return err - } + if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil { + return err } if _, err := fmt.Fprint(updateWriter, "rowid IN (SELECT rowid FROM ", statement.quote(tableName)); err != nil { return err @@ -385,14 +388,8 @@ func (statement *Statement) writeUpdateLimit(updateWriter *builder.BytesWriter, _, err := fmt.Fprintf(updateWriter, " LIMIT %d)", limitValue) return err case schemas.POSTGRES: - if cond.IsValid() { - if _, err := fmt.Fprint(updateWriter, " AND "); err != nil { - return err - } - } else { - if _, err := fmt.Fprint(updateWriter, " WHERE "); err != nil { - return err - } + if err := statement.writeWhereOrAnd(updateWriter, cond.IsValid()); err != nil { + return err } if _, err := fmt.Fprint(updateWriter, "CTID IN (SELECT CTID FROM ", statement.quote(tableName)); err != nil { return err @@ -477,9 +474,9 @@ func (statement *Statement) writeVersionIncrSet(w builder.Writer, v reflect.Valu return nil } -func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) error { +func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSets bool) error { for i, expr := range statement.IncrColumns { - if i > 0 || hasPreviousSet { + if i > 0 || hasPreviousSets { if _, err := fmt.Fprint(w, ", "); err != nil { return err } @@ -492,10 +489,10 @@ func (statement *Statement) writeIncrSets(w builder.Writer, hasPreviousSet bool) return nil } -func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) error { +func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSets bool) error { // for update action to like "column = column - ?" for i, expr := range statement.DecrColumns { - if i > 0 || hasPreviousSet { + if i > 0 || hasPreviousSets { if _, err := fmt.Fprint(w, ", "); err != nil { return err } @@ -508,10 +505,10 @@ func (statement *Statement) writeDecrSets(w builder.Writer, hasPreviousSet bool) return nil } -func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet bool) error { +func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSets bool) error { // for update action to like "column = expression" for i, expr := range statement.ExprColumns { - if i > 0 || hasPreviousSet { + if i > 0 || hasPreviousSets { if _, err := fmt.Fprint(w, ", "); err != nil { return err } @@ -544,33 +541,51 @@ func (statement *Statement) writeExprSets(w *builder.BytesWriter, hasPreviousSet return nil } -func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { - previousLen := w.Len() - for i, colName := range colNames { - if i > 0 { - if _, err := fmt.Fprint(w, ", "); err != nil { +func (statement *Statement) writeSetColumns(colNames []string, args []interface{}) func(w *builder.BytesWriter) error { + return func(w *builder.BytesWriter) error { + if len(colNames) == 0 { + return nil + } + if len(colNames) != len(args) { + return fmt.Errorf("columns elements %d but args elements %d", len(colNames), len(args)) + } + for i, colName := range colNames { + if i > 0 { + if _, err := fmt.Fprint(w, ", "); err != nil { + return err + } + } + if _, err := fmt.Fprint(w, colName); err != nil { return err } } - if _, err := fmt.Fprint(w, colName); err != nil { - return err - } + w.Append(args...) + return nil } - w.Append(args...) +} - if err := statement.writeIncrSets(w, w.Len() > previousLen); err != nil { +func (statement *Statement) writeUpdateSets(w *builder.BytesWriter, v reflect.Value, colNames []string, args []interface{}) error { + if err := statement.writeSetColumns(colNames, args)(w); err != nil { return err } - if err := statement.writeDecrSets(w, w.Len() > previousLen); err != nil { + setNumber := len(colNames) + if err := statement.writeIncrSets(w, setNumber > 0); err != nil { return err } - if err := statement.writeExprSets(w, w.Len() > previousLen); err != nil { + setNumber += len(statement.IncrColumns) + if err := statement.writeDecrSets(w, setNumber > 0); err != nil { return err } - if err := statement.writeVersionIncrSet(w, v, w.Len() > previousLen); err != nil { + setNumber += len(statement.DecrColumns) + if err := statement.writeExprSets(w, setNumber > 0); err != nil { + return err + } + + setNumber += len(statement.ExprColumns) + if err := statement.writeVersionIncrSet(w, v, setNumber > 0); err != nil { return err } return nil diff --git a/tests/engine_test.go b/tests/engine_test.go index 282a96c2..dbe625bd 100644 --- a/tests/engine_test.go +++ b/tests/engine_test.go @@ -21,7 +21,6 @@ import ( _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" - _ "github.com/ziutek/mymysql/godrv" _ "modernc.org/sqlite" ) diff --git a/tests/session_get_test.go b/tests/session_get_test.go index 0258f55e..87e44933 100644 --- a/tests/session_get_test.go +++ b/tests/session_get_test.go @@ -185,8 +185,7 @@ func TestGetVar(t *testing.T) { assert.Equal(t, "28", valuesString["age"]) assert.Equal(t, "1.5", valuesString["money"]) - // for mymysql driver, interface{} will be []byte, so ignore it currently - if testEngine.DriverName() != "mymysql" { + { valuesInter := make(map[string]interface{}) has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter) assert.NoError(t, err) diff --git a/tests/tests.go b/tests/tests.go index a4e3646b..f513ac16 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -112,7 +112,7 @@ func createEngine(dbType, connStr string) error { testEngine, err = xorm.NewEngine(dbType, connStr) } else { testEngine, err = xorm.NewEngineGroup(dbType, strings.Split(connStr, *splitter)) - if dbType != "mysql" && dbType != "mymysql" { + if dbType != "mysql" { *ignoreSelectUpdate = true } }