From e3239710112e47c0f9a06cabd500da234efc8ba1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 21 Jul 2021 00:12:20 +0800 Subject: [PATCH] refactor some code (#2000) Reviewed-on: https://gitea.com/xorm/xorm/pulls/2000 Co-authored-by: Lunny Xiao Co-committed-by: Lunny Xiao --- dialects/driver.go | 6 +++++ dialects/mssql.go | 6 +++++ dialects/mysql.go | 6 +++++ dialects/oracle.go | 6 +++++ dialects/postgres.go | 6 +++++ dialects/sqlite3.go | 6 +++++ session_insert.go | 64 ++++++++++---------------------------------- tags/parser.go | 11 ++++---- tags/tag.go | 12 ++++++--- 9 files changed, 64 insertions(+), 59 deletions(-) diff --git a/dialects/driver.go b/dialects/driver.go index c511b665..c63dbfa3 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -18,9 +18,15 @@ type ScanContext struct { UserLocation *time.Location } +// DriverFeatures represents driver feature +type DriverFeatures struct { + SupportReturnInsertedID bool +} + // Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) + Features() *DriverFeatures GenScanResult(string) (interface{}, error) // according given column type generating a suitable scan interface Scan(*ScanContext, *core.Rows, []*sql.ColumnType, ...interface{}) error } diff --git a/dialects/mssql.go b/dialects/mssql.go index 742928b0..7deade80 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -653,6 +653,12 @@ type odbcDriver struct { baseDriver } +func (p *odbcDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + func (p *odbcDriver) Parse(driverName, dataSourceName string) (*URI, error) { var dbName string diff --git a/dialects/mysql.go b/dialects/mysql.go index 71ee3864..0ad68833 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -674,6 +674,12 @@ type mysqlDriver struct { baseDriver } +func (p *mysqlDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: true, + } +} + func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*URI, error) { dsnPattern := regexp.MustCompile( `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] diff --git a/dialects/oracle.go b/dialects/oracle.go index 902e0c66..11a6653b 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -854,6 +854,12 @@ type godrorDriver struct { baseDriver } +func (g *godrorDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + func (g *godrorDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.ORACLE} dsnPattern := regexp.MustCompile( diff --git a/dialects/postgres.go b/dialects/postgres.go index 6462982d..8a0dd7a8 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -1387,6 +1387,12 @@ func parseOpts(name string, o values) error { return nil } +func (p *pqDriver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: false, + } +} + func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { db := &URI{DBType: schemas.POSTGRES} var err error diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 89f86147..dae6bf93 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -565,6 +565,12 @@ type sqlite3Driver struct { baseDriver } +func (p *sqlite3Driver) Features() *DriverFeatures { + return &DriverFeatures{ + SupportReturnInsertedID: true, + } +} + func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*URI, error) { if strings.Contains(dataSourceName, "?") { dataSourceName = dataSourceName[:strings.Index(dataSourceName, "?")] diff --git a/session_insert.go b/session_insert.go index a9b8b7d2..f35cca53 100644 --- a/session_insert.go +++ b/session_insert.go @@ -9,7 +9,6 @@ import ( "fmt" "reflect" "sort" - "strconv" "strings" "time" @@ -334,13 +333,18 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { cleanupProcessorsClosures(&session.afterClosures) // cleanup after used } - // for postgres, many of them didn't implement lastInsertId, so we should - // implemented it ourself. - if session.engine.dialect.URI().DBType == schemas.ORACLE && len(table.AutoIncrement) > 0 { - res, err := session.queryBytes("select seq_atable.currval from dual", args...) + // if there is auto increment column and driver don't support return it + if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID { + var sql = sqlStr + if session.engine.dialect.URI().DBType == schemas.ORACLE { + sql = "select seq_atable.currval from dual" + } + + rows, err := session.queryRows(sql, args...) if err != nil { return 0, err } + defer rows.Close() defer handleAfterInsertProcessorFunc(bean) @@ -355,56 +359,16 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if len(res) < 1 { - return 0, errors.New("insert no error but not returned id") - } - - idByte := res[0][table.AutoIncrement] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil || id <= 0 { - return 1, err - } - - aiValue, err := table.AutoIncrColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } - - if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { - return 1, nil - } - - return 1, convertAssignV(aiValue.Addr(), id) - } else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES || - session.engine.dialect.URI().DBType == schemas.MSSQL) { - res, err := session.queryBytes(sqlStr, args...) - - if err != nil { - return 0, err - } - defer handleAfterInsertProcessorFunc(bean) - - session.cacheInsert(tableName) - - if table.Version != "" && session.statement.CheckVersion { - verValue, err := table.VersionColumn().ValueOf(bean) - if err != nil { - session.engine.logger.Errorf("%v", err) - } else if verValue.IsValid() && verValue.CanSet() { - session.incrVersionFieldValue(verValue) + var id int64 + if !rows.Next() { + if rows.Err() != nil { + return 0, rows.Err() } - } - - if len(res) < 1 { return 0, errors.New("insert successfully but not returned id") } - - idByte := res[0][table.AutoIncrement] - id, err := strconv.ParseInt(string(idByte), 10, 64) - if err != nil || id <= 0 { + if err := rows.Scan(&id); err != nil { return 1, err } - aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { session.engine.logger.Errorf("%v", err) diff --git a/tags/parser.go b/tags/parser.go index 72baa153..9f9a8f62 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -242,6 +242,10 @@ func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, f } func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) { + if isNotTitle(field.Name) { + return nil, ErrIgnoreField + } + var ( tag = field.Tag ormTagStr = strings.TrimSpace(tag.Get(parser.identifier)) @@ -282,12 +286,7 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { table.Name = names.GetTableName(parser.tableMapper, v) for i := 0; i < t.NumField(); i++ { - var field = t.Field(i) - if isNotTitle(field.Name) { - continue - } - - col, err := parser.parseField(table, i, field, v.Field(i)) + col, err := parser.parseField(table, i, t.Field(i), v.Field(i)) if err == ErrIgnoreField { continue } else if err != nil { diff --git a/tags/tag.go b/tags/tag.go index 641b8c52..32cdb37c 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -101,11 +101,12 @@ type Handler func(ctx *Context) error var ( // defaultTagHandlers enumerates all the default tag handler defaultTagHandlers = map[string]Handler{ + "-": IgnoreHandler, "<-": OnlyFromDBTagHandler, "->": OnlyToDBTagHandler, "PK": PKTagHandler, "NULL": NULLTagHandler, - "NOT": IgnoreTagHandler, + "NOT": NotTagHandler, "AUTOINCR": AutoIncrTagHandler, "DEFAULT": DefaultTagHandler, "CREATED": CreatedTagHandler, @@ -130,11 +131,16 @@ func init() { } } -// IgnoreTagHandler describes ignored tag handler -func IgnoreTagHandler(ctx *Context) error { +// NotTagHandler describes ignored tag handler +func NotTagHandler(ctx *Context) error { return nil } +// IgnoreHandler represetns the field should be ignored +func IgnoreHandler(ctx *Context) error { + return ErrIgnoreField +} + // OnlyFromDBTagHandler describes mapping direction tag handler func OnlyFromDBTagHandler(ctx *Context) error { ctx.col.MapType = schemas.ONLYFROMDB