From a7e010df2dd0e38ac86a587fc760858423a8f480 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 20 Jul 2021 13:46:24 +0800 Subject: [PATCH] refactor insert condition generation (#1998) Reviewed-on: https://gitea.com/xorm/xorm/pulls/1998 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- convert.go | 4 +- convert/interface.go | 1 + convert/time.go | 2 +- dialects/dialect.go | 5 +- dialects/driver.go | 11 -- dialects/mssql.go | 26 +-- dialects/mysql.go | 38 ++-- dialects/oracle.go | 35 ++-- dialects/postgres.go | 38 ++-- dialects/postgres_test.go | 2 - dialects/sqlite3.go | 26 +-- engine.go | 6 +- go.mod | 17 +- go.sum | 35 +++- integrations/engine_test.go | 1 - integrations/session_get_test.go | 10 +- integrations/session_insert_test.go | 2 +- integrations/session_update_test.go | 1 - internal/statements/statement.go | 269 ++++++++++++++-------------- internal/statements/values.go | 2 +- internal/utils/strings.go | 4 +- names/mapper.go | 2 +- rows.go | 35 +--- scan.go | 14 +- schemas/table_test.go | 2 - schemas/type.go | 1 + session.go | 5 +- session_exist.go | 5 +- session_find.go | 11 +- session_get.go | 11 +- session_insert.go | 1 - session_iterate.go | 5 +- session_query.go | 18 +- session_update.go | 6 +- tags/parser.go | 1 + 35 files changed, 324 insertions(+), 328 deletions(-) diff --git a/convert.go b/convert.go index 1aaf5dca..c3eb4de9 100644 --- a/convert.go +++ b/convert.go @@ -373,7 +373,6 @@ func asTime(src interface{}, dbLoc *time.Location, uiLoc *time.Location) (*time. case *sql.NullInt64: tm := time.Unix(t.Int64, 0).In(uiLoc) return &tm, nil - } return nil, fmt.Errorf("unsupported value %#v as time", src) } @@ -751,7 +750,6 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) { } return v, nil } - } return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) } @@ -946,8 +944,10 @@ var ( _ sql.Scanner = &EmptyScanner{} ) +// EmptyScanner represents an empty scanner which will ignore the scan type EmptyScanner struct{} +// Scan implements sql.Scanner func (EmptyScanner) Scan(value interface{}) error { return nil } diff --git a/convert/interface.go b/convert/interface.go index 2b055253..b0f28c81 100644 --- a/convert/interface.go +++ b/convert/interface.go @@ -10,6 +10,7 @@ import ( "time" ) +// Interface2Interface converts interface of pointer as interface of value func Interface2Interface(userLocation *time.Location, v interface{}) (interface{}, error) { if v == nil { return nil, nil diff --git a/convert/time.go b/convert/time.go index 283c7f83..6a53171b 100644 --- a/convert/time.go +++ b/convert/time.go @@ -45,5 +45,5 @@ func String2Time(s string, originalLocation *time.Location, convertedLocation *t return &tm, nil } } - return nil, fmt.Errorf("unsupported convertion from %s to time", s) + return nil, fmt.Errorf("unsupported conversion from %s to time", s) } diff --git a/dialects/dialect.go b/dialects/dialect.go index 81d1ee8d..fc11eac1 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -118,12 +118,9 @@ func (db *Base) HasRecords(queryer core.Queryer, ctx context.Context, query stri defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } return true, nil } - return false, nil + return false, rows.Err() } // IsColumnExist returns true if the column of the table exist diff --git a/dialects/driver.go b/dialects/driver.go index 0b6187d3..c511b665 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,14 +18,9 @@ type ScanContext struct { UserLocation *time.Location } -type DriverFeatures struct { - SupportNullable bool -} - // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) - Features() DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } @@ -82,9 +77,3 @@ type baseDriver struct{} func (b *baseDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, v ...interface{}) error { return rows.Scan(v...) } - -func (b *baseDriver) Features() DriverFeatures { - return DriverFeatures{ - SupportNullable: true, - } -} diff --git a/dialects/mssql.go b/dialects/mssql.go index 08232487..742928b0 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -264,6 +264,9 @@ func (db *mssql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version, level, edition string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -456,9 +459,6 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } var name, ctype, vdefault string var maxLen, precision, scale int var nullable, isPK, defaultIsNull, isIncrement bool @@ -512,6 +512,9 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -527,9 +530,6 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -539,6 +539,9 @@ func (db *mssql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.Name = strings.Trim(name, "` ") tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -562,11 +565,8 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, isUnique string @@ -604,6 +604,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -664,8 +667,7 @@ func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { for _, c := range kv { vv := strings.Split(strings.TrimSpace(c), "=") if len(vv) == 2 { - switch strings.ToLower(vv[0]) { - case "database": + if strings.ToLower(vv[0]) == "database" { dbName = vv[1] } } diff --git a/dialects/mysql.go b/dialects/mysql.go index 88c1038e..71ee3864 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -213,7 +213,10 @@ func (db *mysql) Version(ctx context.Context, queryer core.Queryer) (*schemas.Ve var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -254,9 +257,6 @@ func (db *mysql) SetParams(params map[string]string) { fallthrough case "COMPRESSED": db.rowFormat = t - break - default: - break } } } @@ -405,9 +405,6 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -506,6 +503,9 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -522,9 +522,6 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name, engine string var autoIncr, comment *string @@ -540,6 +537,9 @@ func (db *mysql) GetTables(queryer core.Queryer, ctx context.Context) ([]*schema table.StoreEngine = engine tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -570,11 +570,8 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, nonUnique string err = rows.Scan(&indexName, &nonUnique, &colName) @@ -586,7 +583,7 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName continue } - if "YES" == nonUnique || nonUnique == "1" { + if nonUnique == "YES" || nonUnique == "1" { indexType = schemas.IndexType } else { indexType = schemas.UniqueType @@ -610,6 +607,9 @@ func (db *mysql) GetIndexes(queryer core.Queryer, ctx context.Context, tableName } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -696,14 +696,12 @@ func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { for _, kv := range kvs { splits := strings.Split(kv, "=") if len(splits) == 2 { - switch splits[0] { - case "charset": + if splits[0] == "charset" { uri.Charset = splits[1] } } } } - } } return uri, nil @@ -720,13 +718,13 @@ func (p *mysqlDriver) GenScanResult(colType string) (interface{}, error) { case "TINYINT", "SMALLINT", "MEDIUMINT", "INT": var s sql.NullInt32 return &s, nil - case "FLOAT", "REAL", "DOUBLE PRECISION": + case "FLOAT", "REAL", "DOUBLE PRECISION", "DOUBLE": var s sql.NullFloat64 return &s, nil case "DECIMAL", "NUMERIC": var s sql.NullString return &s, nil - case "DATETIME": + case "DATETIME", "TIMESTAMP": var s sql.NullTime return &s, nil case "BIT": diff --git a/dialects/oracle.go b/dialects/oracle.go index 9240046a..902e0c66 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -525,6 +525,9 @@ func (db *oracle) Version(ctx context.Context, queryer core.Queryer) (*schemas.V var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -677,9 +680,6 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols := make(map[string]*schemas.Column) colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -759,6 +759,9 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -775,9 +778,6 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { @@ -786,6 +786,9 @@ func (db *oracle) GetTables(queryer core.Queryer, ctx context.Context) ([]*schem tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -800,11 +803,8 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } defer rows.Close() - indexes := make(map[string]*schemas.Index, 0) + indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, colName, uniqueness string @@ -838,6 +838,9 @@ func (db *oracle) GetIndexes(queryer core.Queryer, ctx context.Context, tableNam } index.AddColumn(colName) } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -851,7 +854,7 @@ type godrorDriver struct { baseDriver } -func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { +func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] @@ -863,8 +866,7 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) names := dsnPattern.SubexpNames() for i, match := range matches { - switch names[i] { - case "dbname": + if names[i] == "dbname" { db.DBName = match } } @@ -874,7 +876,7 @@ func (cfg *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) return db, nil } -func (p *godrorDriver) GenScanResult(colType string) (interface{}, error) { +func (g *godrorDriver) GenScanResult(colType string) (interface{}, error) { switch colType { case "CHAR", "NCHAR", "VARCHAR", "VARCHAR2", "NVARCHAR2", "LONG", "CLOB", "NCLOB": var s sql.NullString @@ -900,7 +902,7 @@ type oci8Driver struct { // dataSourceName=user/password@ipv4:port/dbname // dataSourceName=user/password@[ipv6]:port/dbname -func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { +func (o *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( `^(?P.*)\/(?P.*)@` + // user:password@ @@ -909,8 +911,7 @@ func (p *oci8Driver) Parse(driverName, dataSourceName string) (*URI, error) { matches := dsnPattern.FindStringSubmatch(dataSourceName) names := dsnPattern.SubexpNames() for i, match := range matches { - switch names[i] { - case "dbname": + if names[i] == "dbname" { db.DBName = match } } diff --git a/dialects/postgres.go b/dialects/postgres.go index e1dca631..6462982d 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -810,6 +810,9 @@ func (db *postgres) Version(ctx context.Context, queryer core.Queryer) (*schemas var version string if !rows.Next() { + if rows.Err() != nil { + return nil, rows.Err() + } return nil, errors.New("unknow version") } @@ -1062,7 +1065,10 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { @@ -1098,9 +1104,6 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A colSeq := make([]string, 0) for rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } col := new(schemas.Column) col.Indexes = make(map[string]int) @@ -1216,6 +1219,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A cols[col.Name] = col colSeq = append(colSeq, col.Name) } + if rows.Err() != nil { + return nil, nil, rows.Err() + } return colSeq, cols, nil } @@ -1237,9 +1243,6 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch tables := make([]*schemas.Table, 0) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } table := schemas.NewEmptyTable() var name string err = rows.Scan(&name) @@ -1249,6 +1252,9 @@ func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*sch table.Name = name tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -1279,9 +1285,6 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var indexType int var indexName, indexdef string var colNames []string @@ -1322,6 +1325,9 @@ func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableN index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -1333,12 +1339,6 @@ type pqDriver struct { baseDriver } -func (b *pqDriver) Features() DriverFeatures { - return DriverFeatures{ - SupportNullable: false, - } -} - type values map[string]string func (vs values) Set(k, v string) { @@ -1459,9 +1459,6 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri } defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return "", rows.Err() - } var defaultSchema string if err = rows.Scan(&defaultSchema); err != nil { return "", err @@ -1469,6 +1466,9 @@ func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (stri parts := strings.Split(defaultSchema, ",") return strings.TrimSpace(parts[len(parts)-1]), nil } + if rows.Err() != nil { + return "", rows.Err() + } return "", errors.New("no default schema") } diff --git a/dialects/postgres_test.go b/dialects/postgres_test.go index c0a8eb6f..e0c36f92 100644 --- a/dialects/postgres_test.go +++ b/dialects/postgres_test.go @@ -76,9 +76,7 @@ func TestParsePgx(t *testing.T) { } else if err == nil && !reflect.DeepEqual(test.expected, uri.DBName) { t.Errorf("%q got: %#v want: %#v", test.in, uri.DBName, test.expected) } - } - } func TestGetIndexColName(t *testing.T) { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index da28d9d1..89f86147 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -169,7 +169,10 @@ func (db *sqlite3) Version(ctx context.Context, queryer core.Queryer) (*schemas. var version string if !rows.Next() { - return nil, errors.New("Unknow version") + if rows.Err() != nil { + return nil, rows.Err() + } + return nil, errors.New("unknow version") } if err := rows.Scan(&version); err != nil { @@ -416,14 +419,14 @@ func (db *sqlite3) GetColumns(queryer core.Queryer, ctx context.Context, tableNa var name string if rows.Next() { - if rows.Err() != nil { - return nil, nil, rows.Err() - } err = rows.Scan(&name) if err != nil { return nil, nil, err } } + if rows.Err() != nil { + return nil, nil, rows.Err() + } if name == "" { return nil, nil, errors.New("no table named " + tableName) @@ -485,6 +488,9 @@ func (db *sqlite3) GetTables(queryer core.Queryer, ctx context.Context) ([]*sche } tables = append(tables, table) } + if rows.Err() != nil { + return nil, rows.Err() + } return tables, nil } @@ -500,9 +506,6 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa indexes := make(map[string]*schemas.Index) for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } var tmpSQL sql.NullString err = rows.Scan(&tmpSQL) if err != nil { @@ -547,6 +550,9 @@ func (db *sqlite3) GetIndexes(queryer core.Queryer, ctx context.Context, tableNa index.IsRegular = isRegular indexes[index.Name] = index } + if rows.Err() != nil { + return nil, rows.Err() + } return indexes, nil } @@ -592,9 +598,3 @@ func (p *sqlite3Driver) GenScanResult(colType string) (interface{}, error) { return &r, nil } } - -func (b *sqlite3Driver) Features() DriverFeatures { - return DriverFeatures{ - SupportNullable: false, - } -} diff --git a/engine.go b/engine.go index 35104b04..20c07e13 100644 --- a/engine.go +++ b/engine.go @@ -551,9 +551,6 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch sess := engine.NewSession() defer sess.Close() for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } _, err = io.WriteString(w, "INSERT INTO "+dstDialect.Quoter().Quote(dstTableName)+" ("+destColNames+") VALUES (") if err != nil { return err @@ -610,6 +607,9 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch return err } } + if rows.Err() != nil { + return rows.Err() + } // FIXME: Hack for postgres if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { diff --git a/go.mod b/go.mod index dbc59e76..1b3baf0c 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,17 @@ module xorm.io/xorm go 1.13 require ( - github.com/denisenkom/go-mssqldb v0.9.0 - github.com/go-sql-driver/mysql v1.5.0 - github.com/json-iterator/go v1.1.11 - github.com/lib/pq v1.7.0 - github.com/mattn/go-sqlite3 v1.14.6 + github.com/denisenkom/go-mssqldb v0.10.0 + github.com/go-sql-driver/mysql v1.6.0 github.com/goccy/go-json v0.7.4 + github.com/json-iterator/go v1.1.11 + github.com/lib/pq v1.10.2 + github.com/mattn/go-sqlite3 v1.14.8 github.com/shopspring/decimal v1.2.0 - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.7.0 github.com/syndtr/goleveldb v1.0.0 github.com/ziutek/mymysql v1.5.4 - modernc.org/sqlite v1.10.1-0.20210314190707-798bbeb9bb84 - xorm.io/builder v0.3.8 + gopkg.in/yaml.v2 v2.2.2 // indirect + modernc.org/sqlite v1.11.2 + xorm.io/builder v0.3.9 ) diff --git a/go.sum b/go.sum index da88d67a..3d4b72a6 100644 --- a/go.sum +++ b/go.sum @@ -5,12 +5,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.9.0 h1:RSohk2RsiZqLZ0zCjtfn3S4Gp4exhpBWHyQ7D0yGjAk= github.com/denisenkom/go-mssqldb v0.9.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/goccy/go-json v0.7.4 h1:B44qRUFwz/vxPKPISQ1KhvzRi9kZ28RAf6YtjriBZ5k= +github.com/goccy/go-json v0.7.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= @@ -28,12 +34,14 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/lib/pq v1.7.0 h1:h93mCPfUSkaul3Ka/VG8uZdmW1uMHDGxzu0NWHuJmHY= github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= -github.com/goccy/go-json v0.7.4 h1:B44qRUFwz/vxPKPISQ1KhvzRi9kZ28RAf6YtjriBZ5k= -github.com/goccy/go-json v0.7.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= +github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= @@ -49,10 +57,13 @@ github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 h1:OdAsTTz6O github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE= github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -103,28 +114,46 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lukechampine.com/uint128 v1.1.1 h1:pnxCASz787iMf+02ssImqk6OLt+Z5QHMoZyUXR4z6JU= +lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= modernc.org/cc/v3 v3.31.5-0.20210308123301-7a3e9dab9009 h1:u0oCo5b9wyLr++HF3AN9JicGhkUxJhMz51+8TIZH9N0= modernc.org/cc/v3 v3.31.5-0.20210308123301-7a3e9dab9009/go.mod h1:0R6jl1aZlIl2avnYfbfHBS1QB6/f+16mihBObaBC878= +modernc.org/cc/v3 v3.33.6 h1:r63dgSzVzRxUpAJFPQWHy1QeZeY1ydNENUDaBx1GqYc= +modernc.org/cc/v3 v3.33.6/go.mod h1:iPJg1pkwXqAV16SNgFBVYmggfMg6xhs+2oiO0vclK3g= modernc.org/ccgo/v3 v3.9.0 h1:JbcEIqjw4Agf+0g3Tc85YvfYqkkFOv6xBwS4zkfqSoA= modernc.org/ccgo/v3 v3.9.0/go.mod h1:nQbgkn8mwzPdp4mm6BT6+p85ugQ7FrGgIcYaE7nSrpY= +modernc.org/ccgo/v3 v3.9.5 h1:dEuUSf8WN51rDkprFuAqjfchKEzN0WttP/Py3enBwjk= +modernc.org/ccgo/v3 v3.9.5/go.mod h1:umuo2EP2oDSBnD3ckjaVUXMrmeAw8C8OSICVa0iFf60= modernc.org/httpfs v1.0.6 h1:AAgIpFZRXuYnkjftxTAZwMIiwEqAfk8aVB2/oA6nAeM= modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= modernc.org/libc v1.7.13-0.20210308123627-12f642a52bb8/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= modernc.org/libc v1.8.0 h1:Pp4uv9g0csgBMpGPABKtkieF6O5MGhfGo6ZiOdlYfR8= modernc.org/libc v1.8.0/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= +modernc.org/libc v1.9.8/go.mod h1:U1eq8YWr/Kc1RWCMFUWEdkTg8OTcfLw2kY8EDwl039w= +modernc.org/libc v1.9.11 h1:QUxZMs48Ahg2F7SN41aERvMfGLY2HU/ADnB9DC4Yts8= +modernc.org/libc v1.9.11/go.mod h1:NyF3tsA5ArIjJ83XB0JlqhjTabTCHm9aX4XMPHyQn0Q= modernc.org/mathutil v1.1.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/mathutil v1.2.2 h1:+yFk8hBprV+4c0U9GjFtL+dV3N8hOJ8JCituQcMShFY= modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/mathutil v1.4.0 h1:GCjoRaBew8ECCKINQA2nYjzvufFW9YiEuuB+rQ9bn2E= +modernc.org/mathutil v1.4.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/memory v1.0.4 h1:utMBrFcpnQDdNsmM6asmyH/FM9TqLPS7XF7otpJmrwM= modernc.org/memory v1.0.4/go.mod h1:nV2OApxradM3/OVbs2/0OsP6nPfakXpi50C7dcoHXlc= modernc.org/opt v0.1.1 h1:/0RX92k9vwVeDXj+Xn23DKp2VJubL7k8qNffND6qn3A= modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= modernc.org/sqlite v1.10.1-0.20210314190707-798bbeb9bb84 h1:rgEUzE849tFlHSoeCrKyS9cZAljC+DY7MdMHKq6R6sY= modernc.org/sqlite v1.10.1-0.20210314190707-798bbeb9bb84/go.mod h1:PGzq6qlhyYjL6uVbSgS6WoF7ZopTW/sI7+7p+mb4ZVU= +modernc.org/sqlite v1.11.2 h1:ShWQpeD3ag/bmx6TqidBlIWonWmQaSQKls3aenCbt+w= +modernc.org/sqlite v1.11.2/go.mod h1:+mhs/P1ONd+6G7hcAs6irwDi/bjTQ7nLW6LHRBsEa3A= modernc.org/strutil v1.1.0 h1:+1/yCzZxY2pZwwrsbH+4T7BQMoLQ9QiBshRC9eicYsc= modernc.org/strutil v1.1.0/go.mod h1:lstksw84oURvj9y3tn8lGvRxyRC1S2+g5uuIzNfIOBs= +modernc.org/strutil v1.1.1 h1:xv+J1BXY3Opl2ALrBwyfEikFAj8pmqcpnfmuwUwcozs= +modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= modernc.org/tcl v1.5.0 h1:euZSUNfE0Fd4W8VqXI1Ly1v7fqDJoBuAV88Ea+SnaSs= modernc.org/tcl v1.5.0/go.mod h1:gb57hj4pO8fRrK54zveIfFXBaMHK3SKJNWcmRw1cRzc= +modernc.org/tcl v1.5.5/go.mod h1:ADkaTUuwukkrlhqwERyq0SM8OvyXo7+TjFz7yAF56EI= modernc.org/token v1.0.0 h1:a0jaWiNMDhDUtqOj09wvjWWAqd3q7WpBulmL9H2egsk= modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/z v1.0.1-0.20210308123920-1f282aa71362/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= @@ -132,3 +161,5 @@ modernc.org/z v1.0.1 h1:WyIDpEpAIx4Hel6q/Pcgj/VhaQV5XPJ2I6ryIYbjnpc= modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= xorm.io/builder v0.3.8 h1:P/wPgRqa9kX5uE0aA1/ukJ23u9KH0aSRpHLwDKXigSE= xorm.io/builder v0.3.8/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= +xorm.io/builder v0.3.9 h1:Sd65/LdWyO7LR8+Cbd+e7mm3sK/7U9k0jS3999IDHMc= +xorm.io/builder v0.3.9/go.mod h1:aUW0S9eb9VCaPohFCH3j7czOx1PMW3i1HrSzbLYGBSE= diff --git a/integrations/engine_test.go b/integrations/engine_test.go index a594ee46..b5ecb2c2 100644 --- a/integrations/engine_test.go +++ b/integrations/engine_test.go @@ -172,7 +172,6 @@ func TestDumpTables(t *testing.T) { name := fmt.Sprintf("dump_%v-table.sql", tp) t.Run(name, func(t *testing.T) { assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, name, tp)) - }) } diff --git a/integrations/session_get_test.go b/integrations/session_get_test.go index b1dffe14..d3ce2a11 100644 --- a/integrations/session_get_test.go +++ b/integrations/session_get_test.go @@ -818,8 +818,9 @@ func TestGetBigFloat(t *testing.T) { } type GetBigFloat2 struct { - Id int64 - Money *big.Float `xorm:"decimal(22,2)"` + Id int64 + Money *big.Float `xorm:"decimal(22,2)"` + Money2 big.Float `xorm:"decimal(22,2)"` } assert.NoError(t, PrepareEngine()) @@ -827,7 +828,8 @@ func TestGetBigFloat(t *testing.T) { { var gf2 = GetBigFloat2{ - Money: big.NewFloat(9999999.99), + Money: big.NewFloat(9999999.99), + Money2: *big.NewFloat(99.99), } _, err := testEngine.Insert(&gf2) assert.NoError(t, err) @@ -845,12 +847,14 @@ func TestGetBigFloat(t *testing.T) { assert.NoError(t, err) assert.True(t, has) assert.True(t, gf3.Money.String() == gf2.Money.String(), "%v != %v", gf3.Money.String(), gf2.Money.String()) + assert.True(t, gf3.Money2.String() == gf2.Money2.String(), "%v != %v", gf3.Money2.String(), gf2.Money2.String()) var gfs []GetBigFloat2 err = testEngine.Find(&gfs) assert.NoError(t, err) assert.EqualValues(t, 1, len(gfs)) assert.True(t, gfs[0].Money.String() == gf2.Money.String(), "%v != %v", gfs[0].Money.String(), gf2.Money.String()) + assert.True(t, gfs[0].Money2.String() == gf2.Money2.String(), "%v != %v", gfs[0].Money2.String(), gf2.Money2.String()) } } diff --git a/integrations/session_insert_test.go b/integrations/session_insert_test.go index a023ab72..ce52d3c4 100644 --- a/integrations/session_insert_test.go +++ b/integrations/session_insert_test.go @@ -202,7 +202,7 @@ func TestInsertDefault2(t *testing.T) { Id int64 Name string Url string `xorm:"text"` - CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00' TIMESTAMP"` + CheckTime time.Time `xorm:"not null default '2000-01-01 00:00:00'"` } di := new(DefaultInsert2) diff --git a/integrations/session_update_test.go b/integrations/session_update_test.go index 22808d60..cc1042b6 100644 --- a/integrations/session_update_test.go +++ b/integrations/session_update_test.go @@ -1313,7 +1313,6 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) { assert.EqualValues(t, true, has) assert.EqualValues(t, "", record.OnlyFromDBField) return &record - } assert.NoError(t, PrepareEngine()) assertSync(t, new(TestOnlyFromDBField)) diff --git a/internal/statements/statement.go b/internal/statements/statement.go index bfe9987f..0e245a96 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -8,6 +8,7 @@ import ( "database/sql/driver" "errors" "fmt" + "math/big" "reflect" "strings" "time" @@ -662,10 +663,6 @@ func (statement *Statement) GenIndexSQL() []string { return sqls } -func uniqueName(tableName, uqeName string) string { - return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) -} - // GenUniqueSQL generates unique SQL func (statement *Statement) GenUniqueSQL() []string { var sqls []string @@ -693,6 +690,138 @@ func (statement *Statement) GenDelIndexSQL() []string { return sqls } +func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect.Type, col *schemas.Column, allUseBool, requiredField bool) (interface{}, bool, error) { + switch fieldType.Kind() { + case reflect.Ptr: + if fieldValue.IsNil() { + return nil, true, nil + } + return statement.asDBCond(fieldValue.Elem(), fieldType.Elem(), col, allUseBool, requiredField) + case reflect.Bool: + if allUseBool || requiredField { + return fieldValue.Interface(), true, nil + } + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + return nil, false, nil + case reflect.String: + if !requiredField && fieldValue.String() == "" { + return nil, false, nil + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + return fieldValue.String(), true, nil + } + return fieldValue.Interface(), true, nil + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + return nil, false, nil + } + return fieldValue.Interface(), true, nil + case reflect.Struct: + if fieldType.ConvertibleTo(schemas.TimeType) { + t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + return nil, false, nil + } + return dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t), true, nil + } else if fieldType.ConvertibleTo(schemas.BigFloatType) { + t := fieldValue.Convert(schemas.BigFloatType).Interface().(big.Float) + v := t.String() + if v == "0" { + return nil, false, nil + } + return t.String(), true, nil + } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { + return nil, false, nil + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ := valNul.Value() + if val == nil && !requiredField { + return nil, false, nil + } + return val, true, nil + } else { + if col.IsJSON { + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + } else { + table, err := statement.tagParser.ParseWithCache(fieldValue) + if err != nil { + return fieldValue.Interface(), true, nil + } + + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { + return pkField.Interface(), true, nil + } + return nil, false, nil + } + return nil, false, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) + } + } + case reflect.Array: + return nil, false, nil + case reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + return nil, false, nil + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + return nil, false, nil + } + + if col.SQLType.IsText() { + bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return string(bytes), true, nil + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + return fieldValue.Bytes(), true, nil + } + return nil, false, nil + } + bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) + if err != nil { + return nil, false, err + } + return bytes, true, nil + } + return nil, false, nil + } + return fieldValue.Interface(), true, nil +} + func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, @@ -747,9 +876,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, continue } - fieldType := reflect.TypeOf(fieldValue.Interface()) requiredField := useAllCols - if b, ok := getFlagForColumn(mustColumnMap, col); ok { if b { requiredField = true @@ -758,6 +885,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } + fieldType := reflect.TypeOf(fieldValue.Interface()) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { @@ -774,131 +902,12 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, } } - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool || requiredField { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if !requiredField && fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if !requiredField && fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if !requiredField && fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if !requiredField && fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType.ConvertibleTo(schemas.TimeType) { - t := fieldValue.Convert(schemas.TimeType).Interface().(time.Time) - if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { - continue - } - val = dialects.FormatColumnTime(statement.dialect, statement.defaultTimeZone, col, t) - } else if _, ok := reflect.New(fieldType).Interface().(convert.Conversion); ok { - continue - } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { - val, _ = valNul.Value() - if val == nil && !requiredField { - continue - } - } else { - if col.IsJSON { - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - table, err := statement.tagParser.ParseWithCache(fieldValue) - if err != nil { - val = fieldValue.Interface() - } else { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { - val = pkField.Interface() - } else { - continue - } - } else { - //TODO: how to handler? - return nil, fmt.Errorf("not supported %v as %v", fieldValue.Interface(), table.PrimaryKeys) - } - } - } - } - case reflect.Array: + val, ok, err := statement.asDBCond(fieldValue, fieldType, col, allUseBool, requiredField) + if err != nil { + return nil, err + } + if !ok { continue - case reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) - if err != nil { - return nil, err - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() } conds = append(conds, builder.Eq{colName: val}) diff --git a/internal/statements/values.go b/internal/statements/values.go index ee3821e9..c572ead5 100644 --- a/internal/statements/values.go +++ b/internal/statements/values.go @@ -23,7 +23,7 @@ var ( bigFloatType = reflect.TypeOf(big.Float{}) ) -// Value2Interface convert a field value of a struct to interface for puting into database +// Value2Interface convert a field value of a struct to interface for putting into database func (statement *Statement) Value2Interface(col *schemas.Column, fieldValue reflect.Value) (interface{}, error) { if fieldValue.CanAddr() { if fieldConvert, ok := fieldValue.Addr().Interface().(convert.Conversion); ok { diff --git a/internal/utils/strings.go b/internal/utils/strings.go index 86469c0f..159e2876 100644 --- a/internal/utils/strings.go +++ b/internal/utils/strings.go @@ -13,7 +13,7 @@ func IndexNoCase(s, sep string) int { return strings.Index(strings.ToLower(s), strings.ToLower(sep)) } -// SplitNoCase split a string by a seperator with no care of capitalize +// SplitNoCase split a string by a separator with no care of capitalize func SplitNoCase(s, sep string) []string { idx := IndexNoCase(s, sep) if idx < 0 { @@ -22,7 +22,7 @@ func SplitNoCase(s, sep string) []string { return strings.Split(s, s[idx:idx+len(sep)]) } -// SplitNNoCase split n by a seperator with no care of capitalize +// SplitNNoCase split n by a separator with no care of capitalize func SplitNNoCase(s, sep string, n int) []string { idx := IndexNoCase(s, sep) if idx < 0 { diff --git a/names/mapper.go b/names/mapper.go index b0ce8076..69f67171 100644 --- a/names/mapper.go +++ b/names/mapper.go @@ -79,7 +79,7 @@ func (m SameMapper) Table2Obj(t string) string { return t } -// SnakeMapper implements IMapper and provides name transaltion between +// SnakeMapper implements IMapper and provides name translation between // struct and database table type SnakeMapper struct { } diff --git a/rows.go b/rows.go index 5e0a1ffe..8e7cc075 100644 --- a/rows.go +++ b/rows.go @@ -5,7 +5,6 @@ package xorm import ( - "database/sql" "errors" "fmt" "reflect" @@ -17,10 +16,9 @@ import ( // Rows rows wrapper a rows to type Rows struct { - session *Session - rows *core.Rows - beanType reflect.Type - lastError error + session *Session + rows *core.Rows + beanType reflect.Type } func newRows(session *Session, bean interface{}) (*Rows, error) { @@ -62,15 +60,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://gitea.com/xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.statement.GetUnscoped() { // tag "deleted" is enabled - var colName = session.engine.Quote(col.Name) - if addedTableName { - var nm = session.statement.TableName() - if len(session.statement.TableAlias) > 0 { - nm = session.statement.TableAlias - } - colName = session.engine.Quote(nm) + "." + colName - } - autoCond = session.statement.CondDeleted(col) } } @@ -86,7 +75,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { rows.rows, err = rows.session.queryRows(sqlStr, args...) if err != nil { - rows.lastError = err rows.Close() return nil, err } @@ -96,25 +84,18 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { // Next move cursor to next record, return false if end has reached func (rows *Rows) Next() bool { - if rows.lastError == nil && rows.rows != nil { - hasNext := rows.rows.Next() - if !hasNext { - rows.lastError = sql.ErrNoRows - } - return hasNext - } - return false + return rows.rows.Next() } // Err returns the error, if any, that was encountered during iteration. Err may be called after an explicit or implicit Close. func (rows *Rows) Err() error { - return rows.lastError + return rows.rows.Err() } // Scan row record to bean properties func (rows *Rows) Scan(bean interface{}) error { - if rows.lastError != nil { - return rows.lastError + if rows.Err() != nil { + return rows.Err() } if reflect.Indirect(reflect.ValueOf(bean)).Type() != rows.beanType { @@ -158,5 +139,5 @@ func (rows *Rows) Close() error { return rows.rows.Close() } - return rows.lastError + return rows.Err() } diff --git a/scan.go b/scan.go index 444aa8ac..ccd6938d 100644 --- a/scan.go +++ b/scan.go @@ -211,12 +211,8 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column scanResult = &sql.RawBytes{} replaced = true default: - var useNullable = true - if engine.driver.Features().SupportNullable { - nullable, ok := types[0].Nullable() - useNullable = ok && nullable - } - if useNullable { + nullable, ok := types[0].Nullable() + if !ok || nullable { scanResult, replaced, err = genScanResultsByBeanNullable(v) } else { scanResult, replaced, err = genScanResultsByBean(v) @@ -286,15 +282,15 @@ func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) { return nil, err } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := row2mapBytes(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } diff --git a/schemas/table_test.go b/schemas/table_test.go index 0e35193f..f352675b 100644 --- a/schemas/table_test.go +++ b/schemas/table_test.go @@ -58,7 +58,6 @@ func TestGetColumnIdx(t *testing.T) { func BenchmarkGetColumnWithToLower(b *testing.B) { for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { - if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok { b.Errorf("Column not found:%s", test.name) } @@ -69,7 +68,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) { func BenchmarkGetColumnIdxWithToLower(b *testing.B) { for i := 0; i < b.N; i++ { for _, test := range testsGetColumn { - if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok { if test.idx < len(c) { continue diff --git a/schemas/type.go b/schemas/type.go index 62e66c2e..d64251bf 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -65,6 +65,7 @@ func (s *SQLType) IsTime() bool { return s.IsType(TIME_TYPE) } +// IsBool returns true if column is a boolean type func (s *SQLType) IsBool() bool { return s.IsType(BOOL_TYPE) } diff --git a/session.go b/session.go index 8c1d8c3b..62d6a770 100644 --- a/session.go +++ b/session.go @@ -391,9 +391,6 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq table *schemas.Table, newElemFunc func([]string) reflect.Value, sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var newValue = newElemFunc(fields) bean := newValue.Interface() dataStruct := newValue.Elem() @@ -415,7 +412,7 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sq bean: bean, }) } - return nil + return rows.Err() } func (session *Session) row2Slice(rows *core.Rows, fields []string, types []*sql.ColumnType, bean interface{}) ([]interface{}, error) { diff --git a/session_exist.go b/session_exist.go index e52c618e..b5e4a655 100644 --- a/session_exist.go +++ b/session_exist.go @@ -25,5 +25,8 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } defer rows.Close() - return rows.Next(), nil + if rows.Next() { + return true, nil + } + return false, rows.Err() } diff --git a/session_find.go b/session_find.go index 89e34e80..010ecd6c 100644 --- a/session_find.go +++ b/session_find.go @@ -255,9 +255,6 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect } for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var newValue = newElemFunc(fields) bean := newValue.Interface() @@ -278,7 +275,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect return err } } - return nil + return rows.Err() } func convertPKToValue(table *schemas.Table, dst interface{}, pk schemas.PK) error { @@ -325,9 +322,6 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in var i int ids = make([]schemas.PK, 0) for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } i++ if i > 500 { session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") @@ -348,6 +342,9 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) diff --git a/session_get.go b/session_get.go index 1062bd9d..08172524 100644 --- a/session_get.go +++ b/session_get.go @@ -159,10 +159,7 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *schemas.Table, defer rows.Close() if !rows.Next() { - if rows.Err() != nil { - return false, rows.Err() - } - return false, nil + return false, rows.Err() } // WARN: Alougth rows return true, but we may also return error. @@ -313,14 +310,14 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf defer rows.Close() if rows.Next() { - if rows.Err() != nil { - return true, rows.Err() - } err = rows.ScanSlice(&res) if err != nil { return true, err } } else { + if rows.Err() != nil { + return false, rows.Err() + } return false, ErrCacheFailed } diff --git a/session_insert.go b/session_insert.go index b41dbbac..a9b8b7d2 100644 --- a/session_insert.go +++ b/session_insert.go @@ -325,7 +325,6 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { copy(afterClosures, session.afterClosures) session.afterInsertBeans[bean] = &afterClosures } - } else { if _, ok := interface{}(bean).(AfterInsertProcessor); ok { session.afterInsertBeans[bean] = nil diff --git a/session_iterate.go b/session_iterate.go index dbbeb3f4..f6301009 100644 --- a/session_iterate.go +++ b/session_iterate.go @@ -43,9 +43,6 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { i := 0 for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } b := reflect.New(rows.beanType).Interface() err = rows.Scan(b) if err != nil { @@ -57,7 +54,7 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { } i++ } - return err + return rows.Err() } // BufferSize sets the buffersize for iterate diff --git a/session_query.go b/session_query.go index 8543ba12..a4070985 100644 --- a/session_query.go +++ b/session_query.go @@ -33,15 +33,15 @@ func (session *Session) rows2Strings(rows *core.Rows) (resultsSlice []map[string } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := session.engine.row2mapStr(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } @@ -57,15 +57,15 @@ func (session *Session) rows2SliceString(rows *core.Rows) (resultsSlice [][]stri } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } record, err := session.engine.row2sliceStr(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, record) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } @@ -120,15 +120,15 @@ func (session *Session) rows2Interfaces(rows *core.Rows) (resultsSlice []map[str return nil, err } for rows.Next() { - if rows.Err() != nil { - return nil, rows.Err() - } result, err := session.engine.row2mapInterface(rows, types, fields) if err != nil { return nil, err } resultsSlice = append(resultsSlice, result) } + if rows.Err() != nil { + return nil, rows.Err() + } return resultsSlice, nil } diff --git a/session_update.go b/session_update.go index 32e28ae0..4f8e6961 100644 --- a/session_update.go +++ b/session_update.go @@ -59,9 +59,6 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = make([]schemas.PK, 0) for rows.Next() { - if rows.Err() != nil { - return rows.Err() - } var res = make([]string, len(table.PrimaryKeys)) err = rows.ScanSlice(&res) if err != nil { @@ -84,6 +81,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = append(ids, pk) } + if rows.Err() != nil { + return rows.Err() + } session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) diff --git a/tags/parser.go b/tags/parser.go index b793a8f1..72baa153 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -124,6 +124,7 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index } } +// ErrIgnoreField represents an error to ignore field var ErrIgnoreField = errors.New("field will be ignored") func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {