From 41388c2f56de7f3f74d56e218eaed30b0945313b Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 29 Feb 2020 08:59:59 +0000 Subject: [PATCH] Use a new ContextLogger interface to implement logger (#1557) Fix bug Add log track on prepare & tx Some improvements remove unused codes refactor logger Fix bug log context add ContextLogger interface Reviewed-on: https://gitea.com/xorm/xorm/pulls/1557 --- core/db.go | 46 +++++++++++- core/stmt.go | 65 ++++++++++++++++- core/tx.go | 87 ++++++++++++++++++++--- dialects/dialect.go | 46 +++--------- dialects/mssql.go | 3 - dialects/mysql.go | 19 ++--- dialects/oracle.go | 4 -- dialects/postgres.go | 6 -- dialects/sqlite3.go | 12 +--- engine.go | 114 ++++++++++++------------------ engine_group.go | 10 +-- interface.go | 7 +- internal/statements/statement.go | 18 ++--- internal/statements/update.go | 2 +- log/logger_context.go | 108 ++++++++++++++++++++++++++++ schemas/table.go | 8 +-- schemas/type.go | 12 ++-- session.go | 25 ++----- session_convert.go | 18 ++--- session_delete.go | 4 +- session_find.go | 24 ++++--- session_get.go | 14 ++-- session_insert.go | 14 ++-- session_raw.go | 45 ++---------- session_schema.go | 2 +- session_tx.go | 116 ++++++++++++++++++++----------- session_update.go | 20 +++--- tags/parser.go | 79 +++++++++++++++------ tags/parser_test.go | 44 ++++++++++++ tags/tag.go | 2 +- tags_test.go | 2 +- xorm.go | 2 +- xorm_test.go | 2 +- 33 files changed, 617 insertions(+), 363 deletions(-) create mode 100644 log/logger_context.go create mode 100644 tags/parser_test.go diff --git a/core/db.go b/core/db.go index 8f16e848..592ccf18 100644 --- a/core/db.go +++ b/core/db.go @@ -12,7 +12,9 @@ import ( "reflect" "regexp" "sync" + "time" + "xorm.io/xorm/log" "xorm.io/xorm/names" ) @@ -81,6 +83,7 @@ type DB struct { Mapper names.Mapper reflectCache map[reflect.Type]*cacheStruct reflectCacheMutex sync.RWMutex + Logger log.SQLLogger } // Open opens a database @@ -120,7 +123,24 @@ func (db *DB) reflectNew(typ reflect.Type) reflect.Value { // QueryContext overwrites sql.DB.QueryContext func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + start := time.Now() + if db.Logger != nil { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } rows, err := db.DB.QueryContext(ctx, query, args...) + if db.Logger != nil { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { if rows != nil { rows.Close() @@ -209,7 +229,7 @@ func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) if err != nil { return nil, err } - return db.DB.ExecContext(ctx, query, args...) + return db.ExecContext(ctx, query, args...) } func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { @@ -221,7 +241,29 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ if err != nil { return nil, err } - return db.DB.ExecContext(ctx, query, args...) + return db.ExecContext(ctx, query, args...) +} + +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + start := time.Now() + if db.Logger != nil { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } + res, err := db.DB.ExecContext(ctx, query, args...) + if db.Logger != nil { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return res, err } func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { diff --git a/core/stmt.go b/core/stmt.go index 8a21541a..d3c46977 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -9,6 +9,9 @@ import ( "database/sql" "errors" "reflect" + "time" + + "xorm.io/xorm/log" ) // Stmt reprents a stmt objects @@ -16,6 +19,7 @@ type Stmt struct { *sql.Stmt db *DB names map[string]int + query string } func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { @@ -27,11 +31,27 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return "?" }) + start := time.Now() + if db.Logger != nil { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + }) + } stmt, err := db.DB.PrepareContext(ctx, query) + if db.Logger != nil { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { return nil, err } - return &Stmt{stmt, db, names}, nil + + return &Stmt{stmt, db, names, query}, nil } func (db *DB) Prepare(query string) (*Stmt, error) { @@ -48,7 +68,7 @@ func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, for k, i := range s.names { args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() } - return s.Stmt.ExecContext(ctx, args...) + return s.ExecContext(ctx, args...) } func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { @@ -65,15 +85,54 @@ func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Resul for k, i := range s.names { args[i] = vv.Elem().FieldByName(k).Interface() } - return s.Stmt.ExecContext(ctx, args...) + return s.ExecContext(ctx, args...) } func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { return s.ExecStructContext(context.Background(), st) } +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + start := time.Now() + if s.db.Logger != nil { + s.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + }) + } + res, err := s.Stmt.ExecContext(ctx, args) + if s.db.Logger != nil { + s.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return res, err +} + func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { + start := time.Now() + if s.db.Logger != nil { + s.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + }) + } rows, err := s.Stmt.QueryContext(ctx, args...) + if s.db.Logger != nil { + s.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: s.query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { return nil, err } diff --git a/core/tx.go b/core/tx.go index a56b7006..10022efc 100644 --- a/core/tx.go +++ b/core/tx.go @@ -7,6 +7,9 @@ package core import ( "context" "database/sql" + "time" + + "xorm.io/xorm/log" ) type Tx struct { @@ -15,7 +18,22 @@ type Tx struct { } func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + start := time.Now() + if db.Logger != nil { + db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: "BEGIN TRANSACTION", + }) + } tx, err := db.DB.BeginTx(ctx, opts) + if db.Logger != nil { + db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: "BEGIN TRANSACTION", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { return nil, err } @@ -23,11 +41,7 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { } func (db *DB) Begin() (*Tx, error) { - tx, err := db.DB.Begin() - if err != nil { - return nil, err - } - return &Tx{tx, db}, nil + return db.BeginTx(context.Background(), nil) } func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { @@ -39,11 +53,26 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return "?" }) + start := time.Now() + if tx.db.Logger != nil { + tx.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + }) + } stmt, err := tx.Tx.PrepareContext(ctx, query) + if tx.db.Logger != nil { + tx.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: "PREPARE", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { return nil, err } - return &Stmt{stmt, tx.db, names}, nil + return &Stmt{stmt, tx.db, names, query}, nil } func (tx *Tx) Prepare(query string) (*Stmt, error) { @@ -64,7 +93,7 @@ func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) if err != nil { return nil, err } - return tx.Tx.ExecContext(ctx, query, args...) + return tx.ExecContext(ctx, query, args...) } func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { @@ -76,7 +105,29 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ if err != nil { return nil, err } - return tx.Tx.ExecContext(ctx, query, args...) + return tx.ExecContext(ctx, query, args...) +} + +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + start := time.Now() + if tx.db.Logger != nil { + tx.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } + res, err := tx.Tx.ExecContext(ctx, query, args...) + if tx.db.Logger != nil { + tx.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } + return res, err } func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { @@ -84,8 +135,28 @@ func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { } func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + start := time.Now() + if tx.db.Logger != nil { + tx.db.Logger.BeforeSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + }) + } rows, err := tx.Tx.QueryContext(ctx, query, args...) + if tx.db.Logger != nil { + tx.db.Logger.AfterSQL(log.LogContext{ + Ctx: ctx, + SQL: query, + Args: args, + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + } if err != nil { + if rows != nil { + rows.Close() + } return nil, err } return &Rows{rows, tx.db}, nil diff --git a/dialects/dialect.go b/dialects/dialect.go index e9e512ee..a0139d9f 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -11,14 +11,11 @@ import ( "time" "xorm.io/xorm/core" - "xorm.io/xorm/log" "xorm.io/xorm/schemas" ) -type DBType string - type URI struct { - DBType DBType + DBType schemas.DBType Proto string Host string Port string @@ -32,13 +29,12 @@ type URI struct { Schema string } -// a dialect is a driver's wrapper +// Dialect represents a kind of database type Dialect interface { - SetLogger(logger log.Logger) Init(*core.DB, *URI, string, string) error URI() *URI DB() *core.DB - DBType() DBType + DBType() schemas.DBType SQLType(*schemas.Column) string FormatBytes(b []byte) string DefaultSchema() string @@ -49,7 +45,6 @@ type Dialect interface { IsReserved(string) bool Quoter() schemas.Quoter - RollBackStr() string AutoIncrStr() string SupportInsertMany() bool @@ -92,7 +87,6 @@ type Base struct { dialect Dialect driverName string dataSourceName string - logger log.Logger uri *URI } @@ -100,10 +94,6 @@ func (b *Base) DB() *core.DB { return b.db } -func (b *Base) SetLogger(logger log.Logger) { - b.logger = logger -} - func (b *Base) DefaultSchema() string { return "" } @@ -118,7 +108,7 @@ func (b *Base) URI() *URI { return b.uri } -func (b *Base) DBType() DBType { +func (b *Base) DBType() schemas.DBType { return b.uri.DBType } @@ -187,10 +177,6 @@ func (b *Base) DataSourceName() string { return b.dataSourceName } -func (db *Base) RollBackStr() string { - return "ROLL BACK" -} - func (db *Base) SupportDropIfExists() bool { return true } @@ -201,7 +187,6 @@ func (db *Base) DropTableSQL(tableName string) string { } func (db *Base) HasRecords(ctx context.Context, query string, args ...interface{}) (bool, error) { - db.LogSQL(query, args) rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { return false, err @@ -229,13 +214,8 @@ func (db *Base) IsColumnExist(ctx context.Context, tableName, colName string) (b } func (db *Base) AddColumnSQL(tableName string, col *schemas.Column) string { - quoter := db.dialect.Quoter() - sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), + return fmt.Sprintf("ALTER TABLE %v ADD %v", db.dialect.Quoter().Quote(tableName), db.String(col)) - if db.dialect.DBType() == schemas.MYSQL && len(col.Comment) > 0 { - sql += " COMMENT '" + col.Comment + "'" - } - return sql } func (db *Base) CreateIndexSQL(tableName string, index *schemas.Index) string { @@ -323,16 +303,6 @@ func (b *Base) ForUpdateSQL(query string) string { return query + " FOR UPDATE" } -func (b *Base) LogSQL(sql string, args []interface{}) { - if b.logger != nil && b.logger.IsShowSQL() { - if len(args) > 0 { - b.logger.Infof("[SQL] %v %v", sql, args) - } else { - b.logger.Infof("[SQL] %v", sql) - } - } -} - func (b *Base) SetParams(params map[string]string) { } @@ -341,7 +311,7 @@ var ( ) // RegisterDialect register database dialect -func RegisterDialect(dbName DBType, dialectFunc func() Dialect) { +func RegisterDialect(dbName schemas.DBType, dialectFunc func() Dialect) { if dialectFunc == nil { panic("core: Register dialect is nil") } @@ -349,7 +319,7 @@ func RegisterDialect(dbName DBType, dialectFunc func() Dialect) { } // QueryDialect query if registered database dialect -func QueryDialect(dbName DBType) Dialect { +func QueryDialect(dbName schemas.DBType) Dialect { if d, ok := dialects[strings.ToLower(string(dbName))]; ok { return d() } @@ -358,7 +328,7 @@ func QueryDialect(dbName DBType) Dialect { func regDrvsNDialects() bool { providedDrvsNDialects := map[string]struct { - dbType DBType + dbType schemas.DBType getDriver func() Driver getDialect func() Dialect }{ diff --git a/dialects/mssql.go b/dialects/mssql.go index 83844f4e..9963fc4f 100644 --- a/dialects/mssql.go +++ b/dialects/mssql.go @@ -351,7 +351,6 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma LEFT OUTER JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id where a.object_id=object_id('` + tableName + `')` - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -411,7 +410,6 @@ func (db *mssql) GetColumns(ctx context.Context, tableName string) ([]string, ma func (db *mssql) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -446,7 +444,6 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { diff --git a/dialects/mysql.go b/dialects/mysql.go index 62fc6eb1..5ed2d8f1 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -303,23 +303,26 @@ func (db *mysql) IndexCheckSQL(tableName, idxName string) (string, []interface{} return sql, args } -/*func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{db.DbName, tableName, colName} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - return sql, args -}*/ - func (db *mysql) TableCheckSQL(tableName string) (string, []interface{}) { args := []interface{}{db.uri.DBName, tableName} sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" return sql, args } +func (db *mysql) AddColumnSQL(tableName string, col *schemas.Column) string { + quoter := db.dialect.Quoter() + sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quoter.Quote(tableName), + db.String(col)) + if len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" + } + return sql +} + func (db *mysql) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA`,`COLUMN_COMMENT` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -430,7 +433,6 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{db.uri.DBName} s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT`, `TABLE_COMMENT` from " + "`INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? AND (`ENGINE`='MyISAM' OR `ENGINE` = 'InnoDB' OR `ENGINE` = 'TokuDB')" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -459,7 +461,6 @@ func (db *mysql) GetTables(ctx context.Context) ([]*schemas.Table, error) { func (db *mysql) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{db.uri.DBName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { diff --git a/dialects/oracle.go b/dialects/oracle.go index 1247d7a4..3b9989d9 100644 --- a/dialects/oracle.go +++ b/dialects/oracle.go @@ -635,7 +635,6 @@ func (db *oracle) IsColumnExist(ctx context.Context, tableName, colName string) args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" - db.LogSQL(query, args) rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { @@ -653,7 +652,6 @@ func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, m args := []interface{}{tableName} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -750,7 +748,6 @@ func (db *oracle) GetColumns(ctx context.Context, tableName string) ([]string, m func (db *oracle) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -775,7 +772,6 @@ func (db *oracle) GetIndexes(ctx context.Context, tableName string) (map[string] args := []interface{}{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { diff --git a/dialects/postgres.go b/dialects/postgres.go index 94514e95..2e314812 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -943,7 +943,6 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" } - db.LogSQL(query, args) rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { @@ -975,8 +974,6 @@ WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.att } s = fmt.Sprintf(s, f) - db.LogSQL(s, args) - rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err @@ -1077,8 +1074,6 @@ func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) { s = s + " WHERE schemaname = $1" } - db.LogSQL(s, args) - rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, err @@ -1117,7 +1112,6 @@ func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[strin args = append(args, db.uri.Schema) s = s + " AND schemaname=$2" } - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { diff --git a/dialects/sqlite3.go b/dialects/sqlite3.go index 5511468f..7dfa7fca 100644 --- a/dialects/sqlite3.go +++ b/dialects/sqlite3.go @@ -249,16 +249,10 @@ func (db *sqlite3) ForUpdateSQL(query string) string { return query } -/*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName} - sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - return sql, args -}*/ - func (db *sqlite3) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { args := []interface{}{tableName} query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - db.LogSQL(query, args) + rows, err := db.DB().QueryContext(ctx, query, args...) if err != nil { return false, err @@ -336,7 +330,7 @@ func parseString(colStr string) (*schemas.Column, error) { func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - db.LogSQL(s, args) + rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { return nil, nil, err @@ -393,7 +387,6 @@ func (db *sqlite3) GetColumns(ctx context.Context, tableName string) ([]string, func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { @@ -419,7 +412,6 @@ func (db *sqlite3) GetTables(ctx context.Context) ([]*schemas.Table, error) { func (db *sqlite3) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - db.LogSQL(s, args) rows, err := db.DB().QueryContext(ctx, s, args...) if err != nil { diff --git a/engine.go b/engine.go index cf0126e9..b34f0716 100644 --- a/engine.go +++ b/engine.go @@ -31,22 +31,16 @@ import ( // Engine is the major struct of xorm, it means a database manager. // Commonly, an application only need one engine type Engine struct { - db *core.DB - dialect dialects.Dialect + cacherMgr *caches.Manager + db *core.DB + defaultContext context.Context + dialect dialects.Dialect + engineGroup *EngineGroup + logger log.ContextLogger + tagParser *tags.Parser - showSQL bool - showExecTime bool - - logger log.Logger TZLocation *time.Location // The timezone of the application DatabaseTZ *time.Location // The timezone of the database - - engineGroup *EngineGroup - - defaultContext context.Context - - tagParser *tags.Parser - cacherMgr *caches.Manager } func (engine *Engine) SetCacher(tableName string, cacher caches.Cacher) { @@ -67,32 +61,33 @@ func (engine *Engine) BufferSize(size int) *Session { // ShowSQL show SQL statement or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) - if len(show) == 0 { - engine.showSQL = true + if engine.logger.IsShowSQL() { + engine.db.Logger = engine.logger } else { - engine.showSQL = show[0] - } -} - -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -func (engine *Engine) ShowExecTime(show ...bool) { - if len(show) == 0 { - engine.showExecTime = true - } else { - engine.showExecTime = show[0] + engine.db.Logger = &log.DiscardSQLLogger{} } } // Logger return the logger interface -func (engine *Engine) Logger() log.Logger { +func (engine *Engine) Logger() log.ContextLogger { return engine.logger } // SetLogger set the new logger -func (engine *Engine) SetLogger(logger log.Logger) { - engine.logger = logger - engine.showSQL = logger.IsShowSQL() - engine.dialect.SetLogger(logger) +func (engine *Engine) SetLogger(logger interface{}) { + var realLogger log.ContextLogger + switch t := logger.(type) { + case log.Logger: + realLogger = log.NewLoggerAdapter(t) + case log.ContextLogger: + realLogger = t + } + engine.logger = realLogger + if realLogger.IsShowSQL() { + engine.db.Logger = realLogger + } else { + engine.db.Logger = &log.DiscardSQLLogger{} + } } // SetLogLevel sets the logger level @@ -123,12 +118,12 @@ func (engine *Engine) SetMapper(mapper names.Mapper) { // SetTableMapper set the table name mapping rule func (engine *Engine) SetTableMapper(mapper names.Mapper) { - engine.tagParser.TableMapper = mapper + engine.tagParser.SetTableMapper(mapper) } // SetColumnMapper set the column name mapping rule func (engine *Engine) SetColumnMapper(mapper names.Mapper) { - engine.tagParser.ColumnMapper = mapper + engine.tagParser.SetColumnMapper(mapper) } // SupportInsertMany If engine's database support batch insert records like @@ -255,17 +250,6 @@ func (engine *Engine) Ping() error { return session.Ping() } -// logSQL save sql -func (engine *Engine) logSQL(sqlStr string, sqlArgs ...interface{}) { - if engine.showSQL && !engine.showExecTime { - if len(sqlArgs) > 0 { - engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) - } else { - engine.logger.Infof("[SQL] %v", sqlStr) - } - } -} - // SQL method let's you manually write raw SQL and operate // For example: // @@ -336,7 +320,7 @@ func (engine *Engine) DBMetas() ([]*schemas.Table, error) { } // DumpAllToFile dump database all table structs and data to a file -func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error { +func (engine *Engine) DumpAllToFile(fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -346,7 +330,7 @@ func (engine *Engine) DumpAllToFile(fp string, tp ...dialects.DBType) error { } // DumpAll dump database all table structs and data to w -func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error { +func (engine *Engine) DumpAll(w io.Writer, tp ...schemas.DBType) error { tables, err := engine.DBMetas() if err != nil { return err @@ -355,7 +339,7 @@ func (engine *Engine) DumpAll(w io.Writer, tp ...dialects.DBType) error { } // DumpTablesToFile dump specified tables to SQL file. -func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...dialects.DBType) error { +func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp ...schemas.DBType) error { f, err := os.Create(fp) if err != nil { return err @@ -365,12 +349,12 @@ func (engine *Engine) DumpTablesToFile(tables []*schemas.Table, fp string, tp .. } // DumpTables dump specify tables to io.Writer -func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error { +func (engine *Engine) DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { return engine.dumpTables(tables, w, tp...) } // dumpTables dump database all table structs and data to w with specify db type -func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dialects.DBType) error { +func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error { var dialect dialects.Dialect var distDBName string if len(tp) == 0 { @@ -496,7 +480,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...dia } // FIXME: Hack for postgres - if string(dialect.DBType()) == schemas.POSTGRES && table.AutoIncrColumn() != nil { + if dialect.DBType() == schemas.POSTGRES && table.AutoIncrColumn() != nil { _, err = io.WriteString(w, "SELECT setval('"+table.Name+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dialect.Quoter().Quote(table.Name)+"), 1), false);\n") if err != nil { return err @@ -739,13 +723,9 @@ func (t *Table) IsValid() bool { } // TableInfo get table info according to bean's content -func (engine *Engine) TableInfo(bean interface{}) (*Table, error) { +func (engine *Engine) TableInfo(bean interface{}) (*schemas.Table, error) { v := utils.ReflectValue(bean) - tb, err := engine.tagParser.MapType(v) - if err != nil { - return nil, err - } - return &Table{tb, dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean)}, nil + return engine.tagParser.ParseWithCache(v) } // IsTableEmpty if a table has any reocrd @@ -763,7 +743,7 @@ func (engine *Engine) IsTableExist(beanOrTableName interface{}) (bool, error) { } // IDOf get id from one struct -func (engine *Engine) IDOf(bean interface{}) schemas.PK { +func (engine *Engine) IDOf(bean interface{}) (schemas.PK, error) { return engine.IDOfV(reflect.ValueOf(bean)) } @@ -773,18 +753,13 @@ func (engine *Engine) TableName(bean interface{}, includeSchema ...bool) string } // IDOfV get id from one value of struct -func (engine *Engine) IDOfV(rv reflect.Value) schemas.PK { - pk, err := engine.idOfV(rv) - if err != nil { - engine.logger.Error(err) - return nil - } - return pk +func (engine *Engine) IDOfV(rv reflect.Value) (schemas.PK, error) { + return engine.idOfV(rv) } func (engine *Engine) idOfV(rv reflect.Value) (schemas.PK, error) { v := reflect.Indirect(rv) - table, err := engine.tagParser.MapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return nil, err } @@ -882,7 +857,7 @@ func (engine *Engine) ClearCache(beans ...interface{}) error { // UnMapType remove table from tables cache func (engine *Engine) UnMapType(t reflect.Type) { - engine.tagParser.ClearTable(t) + engine.tagParser.ClearCacheTable(t) } // Sync the new struct changes to database, this method will automatically add @@ -895,7 +870,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { for _, bean := range beans { v := utils.ReflectValue(bean) tableNameNoSchema := dialects.FullTableName(engine.dialect, engine.GetTableMapper(), bean) - table, err := engine.tagParser.MapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } @@ -1216,8 +1191,7 @@ func (engine *Engine) Import(r io.Reader) ([]sql.Result, error) { for scanner.Scan() { query := strings.Trim(scanner.Text(), " \t\n\r") if len(query) > 0 { - engine.logSQL(query) - result, err := engine.DB().Exec(query) + result, err := engine.DB().ExecContext(engine.defaultContext, query) results = append(results, result) if err != nil { return nil, err @@ -1244,12 +1218,12 @@ func (engine *Engine) formatColTime(col *schemas.Column, t time.Time) (v interfa // GetColumnMapper returns the column name mapper func (engine *Engine) GetColumnMapper() names.Mapper { - return engine.tagParser.ColumnMapper + return engine.tagParser.GetColumnMapper() } // GetTableMapper returns the table name mapper func (engine *Engine) GetTableMapper() names.Mapper { - return engine.tagParser.TableMapper + return engine.tagParser.GetTableMapper() } // GetTZLocation returns time zone of the application diff --git a/engine_group.go b/engine_group.go index 55159d55..8177697e 100644 --- a/engine_group.go +++ b/engine_group.go @@ -135,7 +135,7 @@ func (eg *EngineGroup) SetDefaultCacher(cacher caches.Cacher) { } // SetLogger set the new logger -func (eg *EngineGroup) SetLogger(logger log.Logger) { +func (eg *EngineGroup) SetLogger(logger interface{}) { eg.Engine.SetLogger(logger) for i := 0; i < len(eg.slaves); i++ { eg.slaves[i].SetLogger(logger) @@ -188,14 +188,6 @@ func (eg *EngineGroup) SetTableMapper(mapper names.Mapper) { } } -// ShowExecTime show SQL statement and execute time or not on logger if log level is great than INFO -func (eg *EngineGroup) ShowExecTime(show ...bool) { - eg.Engine.ShowExecTime(show...) - for i := 0; i < len(eg.slaves); i++ { - eg.slaves[i].ShowExecTime(show...) - } -} - // ShowSQL show SQL statement or not on logger if log level is great than INFO func (eg *EngineGroup) ShowSQL(show ...bool) { eg.Engine.ShowSQL(show...) diff --git a/interface.go b/interface.go index e7894012..13f1e12a 100644 --- a/interface.go +++ b/interface.go @@ -83,7 +83,7 @@ type EngineInterface interface { DBMetas() ([]*schemas.Table, error) Dialect() dialects.Dialect DropTables(...interface{}) error - DumpAllToFile(fp string, tp ...dialects.DBType) error + DumpAllToFile(fp string, tp ...schemas.DBType) error GetCacher(string) caches.Cacher GetColumnMapper() names.Mapper GetDefaultCacher() caches.Cacher @@ -98,7 +98,7 @@ type EngineInterface interface { SetConnMaxLifetime(time.Duration) SetColumnMapper(names.Mapper) SetDefaultCacher(caches.Cacher) - SetLogger(logger log.Logger) + SetLogger(logger interface{}) SetLogLevel(log.LogLevel) SetMapper(names.Mapper) SetMaxOpenConns(int) @@ -107,12 +107,11 @@ type EngineInterface interface { SetTableMapper(names.Mapper) SetTZDatabase(tz *time.Location) SetTZLocation(tz *time.Location) - ShowExecTime(...bool) ShowSQL(show ...bool) Sync(...interface{}) error Sync2(...interface{}) error StoreEngine(storeEngine string) *Session - TableInfo(bean interface{}) (*Table, error) + TableInfo(bean interface{}) (*schemas.Table, error) TableName(interface{}, ...bool) string UnMapType(reflect.Type) } diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 92b1809a..68738b90 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -253,11 +253,11 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement func (statement *Statement) SetRefValue(v reflect.Value) error { var err error - statement.RefTable, err = statement.tagParser.MapType(reflect.Indirect(v)) + statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) if err != nil { return err } - statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, v, true) + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), v, true) return nil } @@ -267,11 +267,11 @@ func rValue(bean interface{}) reflect.Value { func (statement *Statement) SetRefBean(bean interface{}) error { var err error - statement.RefTable, err = statement.tagParser.MapType(rValue(bean)) + statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) if err != nil { return err } - statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, bean, true) + statement.tableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), bean, true) return nil } @@ -507,13 +507,13 @@ func (statement *Statement) SetTable(tableNameOrBean interface{}) error { t := v.Type() if t.Kind() == reflect.Struct { var err error - statement.RefTable, err = statement.tagParser.MapType(v) + statement.RefTable, err = statement.tagParser.ParseWithCache(v) if err != nil { return err } } - statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tableNameOrBean, true) + statement.AltTableName = dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tableNameOrBean, true) return nil } @@ -554,7 +554,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: - tbName := dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, tablename, true) + tbName := dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), tablename, true) if !utils.IsSubQuery(tbName) { var buf strings.Builder statement.dialect.Quoter().QuoteTo(&buf, tbName) @@ -689,7 +689,7 @@ func (statement *Statement) GenDelIndexSQL() []string { } else if index.Type == schemas.IndexType { rIdxName = utils.IndexName(idxPrefixName, idxName) } - sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.TableMapper, rIdxName, true))) + sql := fmt.Sprintf("DROP INDEX %v", statement.quote(dialects.FullTableName(statement.dialect, statement.tagParser.GetTableMapper(), rIdxName, true))) if statement.dialect.IndexOnTable() { sql += fmt.Sprintf(" ON %v", statement.quote(tbName)) } @@ -844,7 +844,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, val = bytes } } else { - table, err := statement.tagParser.MapType(fieldValue) + table, err := statement.tagParser.ParseWithCache(fieldValue) if err != nil { val = fieldValue.Interface() } else { diff --git a/internal/statements/update.go b/internal/statements/update.go index a5d7ec5a..e9cdd98c 100644 --- a/internal/statements/update.go +++ b/internal/statements/update.go @@ -187,7 +187,7 @@ func (statement *Statement) BuildUpdates(bean interface{}, val, _ = nulType.Value() } else { if !col.SQLType.IsJson() { - table, err := statement.tagParser.MapType(fieldValue) + table, err := statement.tagParser.ParseWithCache(fieldValue) if err != nil { val = fieldValue.Interface() } else { diff --git a/log/logger_context.go b/log/logger_context.go new file mode 100644 index 00000000..b05f1c52 --- /dev/null +++ b/log/logger_context.go @@ -0,0 +1,108 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package log + +import ( + "context" + "time" +) + +// LogContext represents a log context +type LogContext struct { + Ctx context.Context + SQL string // log content or SQL + Args []interface{} // if it's a SQL, it's the arguments + ExecuteTime time.Duration + Err error // SQL executed error +} + +type SQLLogger interface { + BeforeSQL(context LogContext) + AfterSQL(context LogContext) +} + +type DiscardSQLLogger struct{} + +var _ SQLLogger = &DiscardSQLLogger{} + +func (DiscardSQLLogger) BeforeSQL(LogContext) {} +func (DiscardSQLLogger) AfterSQL(LogContext) {} + +// ContextLogger represents a logger interface with context +type ContextLogger interface { + SQLLogger + + Debugf(format string, v ...interface{}) + Errorf(format string, v ...interface{}) + Infof(format string, v ...interface{}) + Warnf(format string, v ...interface{}) + + Level() LogLevel + SetLevel(l LogLevel) + + ShowSQL(show ...bool) + IsShowSQL() bool +} + +var ( + _ ContextLogger = &LoggerAdapter{} +) + +// LoggerAdapter wraps a Logger interafce as LoggerContext interface +type LoggerAdapter struct { + logger Logger +} + +func NewLoggerAdapter(logger Logger) ContextLogger { + return &LoggerAdapter{ + logger: logger, + } +} + +func (l *LoggerAdapter) BeforeSQL(ctx LogContext) {} + +func (l *LoggerAdapter) AfterSQL(ctx LogContext) { + if !l.logger.IsShowSQL() { + return + } + + if ctx.ExecuteTime > 0 { + l.logger.Infof("[SQL] %v %v - %v", ctx.SQL, ctx.Args, ctx.ExecuteTime) + } else { + l.logger.Infof("[SQL] %v %v", ctx.SQL, ctx.Args) + } +} + +func (l *LoggerAdapter) Debugf(format string, v ...interface{}) { + l.logger.Debugf(format, v...) +} + +func (l *LoggerAdapter) Errorf(format string, v ...interface{}) { + l.logger.Errorf(format, v...) +} + +func (l *LoggerAdapter) Infof(format string, v ...interface{}) { + l.logger.Infof(format, v...) +} + +func (l *LoggerAdapter) Warnf(format string, v ...interface{}) { + l.logger.Warnf(format, v...) +} + +func (l *LoggerAdapter) Level() LogLevel { + return l.logger.Level() +} + +func (l *LoggerAdapter) SetLevel(lv LogLevel) { + l.logger.SetLevel(lv) +} + +func (l *LoggerAdapter) ShowSQL(show ...bool) { + l.logger.ShowSQL(show...) +} + +func (l *LoggerAdapter) IsShowSQL() bool { + return l.logger.IsShowSQL() +} diff --git a/schemas/table.go b/schemas/table.go index 44aa8152..2dac3ea2 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -7,7 +7,6 @@ package schemas import ( "reflect" "strings" - //"xorm.io/xorm/cache" ) // Table represents a database table @@ -24,10 +23,9 @@ type Table struct { Updated string Deleted string Version string - //Cacher caches.Cacher - StoreEngine string - Charset string - Comment string + StoreEngine string + Charset string + Comment string } func NewEmptyTable() *Table { diff --git a/schemas/type.go b/schemas/type.go index 2aaa2a44..39f1bf4e 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -11,12 +11,14 @@ import ( "time" ) +type DBType string + const ( - POSTGRES = "postgres" - SQLITE = "sqlite3" - MYSQL = "mysql" - MSSQL = "mssql" - ORACLE = "oracle" + POSTGRES DBType = "postgres" + SQLITE DBType = "sqlite3" + MYSQL DBType = "mysql" + MSSQL DBType = "mssql" + ORACLE DBType = "oracle" ) // SQLType represents SQL types diff --git a/session.go b/session.go index 92063882..287465ca 100644 --- a/session.go +++ b/session.go @@ -82,7 +82,7 @@ func (session *Session) Init() { session.engine.DatabaseTZ, ) - session.showSQL = session.engine.showSQL + //session.showSQL = session.engine.showSQL session.isAutoCommit = true session.isCommitedOrRollbacked = false session.isAutoClose = false @@ -165,7 +165,7 @@ func (session *Session) After(closures func(interface{})) *Session { // Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { if err := session.statement.SetTable(tableNameOrBean); err != nil { - session.engine.logger.Error(err) + session.statement.LastError = err } return session } @@ -447,7 +447,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b fieldValue, err := session.getField(dataStruct, key, table, idx) if err != nil { if !strings.Contains(err.Error(), "is not valid") { - session.engine.logger.Warn(err) + session.engine.logger.Warnf("%v", err) } continue } @@ -650,7 +650,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.byte2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -659,7 +659,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b hasAssigned = true t, err := session.str2Time(col, d) if err != nil { - session.engine.logger.Error("byte2Time error:", err.Error()) + session.engine.logger.Errorf("byte2Time error: %v", err) hasAssigned = false } else { fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) @@ -672,7 +672,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b // !! 增加支持sql.Scanner接口的结构,如sql.NullString hasAssigned = true if err := nulVal.Scan(vv.Interface()); err != nil { - session.engine.logger.Error("sql.Sanner error:", err.Error()) + session.engine.logger.Errorf("sql.Sanner error: %v", err) hasAssigned = false } } else if col.SQLType.IsJson() { @@ -698,7 +698,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b } } } else if session.statement.UseCascade { - table, err := session.engine.tagParser.MapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return nil, err } @@ -865,17 +865,6 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b func (session *Session) saveLastSQL(sql string, args ...interface{}) { session.lastSQL = sql session.lastSQLArgs = args - session.logSQL(sql, args...) -} - -func (session *Session) logSQL(sqlStr string, sqlArgs ...interface{}) { - if session.showSQL && !session.engine.showExecTime { - if len(sqlArgs) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, sqlArgs) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } } // LastSQL returns last query information diff --git a/session_convert.go b/session_convert.go index 735aefa6..1cd00627 100644 --- a/session_convert.go +++ b/session_convert.go @@ -111,7 +111,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val if len(data) > 0 { err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -125,7 +124,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val if len(data) > 0 { err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -138,7 +136,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val if len(data) > 0 { err := json.DefaultJSONHandler.Unmarshal(data, x.Interface()) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(x.Elem()) @@ -210,7 +207,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val v = x fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) } else if session.statement.UseCascade { - table, err := session.engine.tagParser.MapType(*fieldValue) + table, err := session.engine.tagParser.ParseWithCache(*fieldValue) if err != nil { return err } @@ -267,7 +264,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val if len(data) > 0 { err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) @@ -278,7 +274,6 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val if len(data) > 0 { err := json.DefaultJSONHandler.Unmarshal(data, &x) if err != nil { - session.engine.logger.Error(err) return err } fieldValue.Set(reflect.ValueOf(&x).Convert(fieldType)) @@ -493,7 +488,7 @@ func (session *Session) bytes2Value(col *schemas.Column, fieldValue *reflect.Val default: if session.statement.UseCascade { structInter := reflect.New(fieldType.Elem()) - table, err := session.engine.tagParser.MapType(structInter.Elem()) + table, err := session.engine.tagParser.ParseWithCache(structInter.Elem()) if err != nil { return err } @@ -570,7 +565,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. if fieldValue.IsNil() { return nil, nil } else if !fieldValue.IsValid() { - session.engine.logger.Warn("the field[", col.FieldName, "] is invalid") + session.engine.logger.Warnf("the field [%s] is invalid", col.FieldName) return nil, nil } else { // !nashtsai! deference pointer type to instance type @@ -604,7 +599,7 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. return v.Value() } - fieldTable, err := session.engine.tagParser.MapType(fieldValue) + fieldTable, err := session.engine.tagParser.ParseWithCache(fieldValue) if err != nil { return nil, err } @@ -618,14 +613,12 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. if col.SQLType.IsText() { bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) return 0, err } return string(bytes), nil } else if col.SQLType.IsBlob() { bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) return 0, err } return bytes, nil @@ -634,7 +627,6 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. case reflect.Complex64, reflect.Complex128: bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) return 0, err } return string(bytes), nil @@ -646,7 +638,6 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. if col.SQLType.IsText() { bytes, err := json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) return 0, err } return string(bytes), nil @@ -659,7 +650,6 @@ func (session *Session) value2Interface(col *schemas.Column, fieldValue reflect. } else { bytes, err = json.DefaultJSONHandler.Marshal(fieldValue.Interface()) if err != nil { - session.engine.logger.Error(err) return 0, err } } diff --git a/session_delete.go b/session_delete.go index f21151e1..04200035 100644 --- a/session_delete.go +++ b/session_delete.go @@ -62,14 +62,14 @@ func (session *Session) cacheDelete(table *schemas.Table, tableName, sqlStr stri } for _, id := range ids { - session.engine.logger.Debug("[cacheDelete] delete cache obj:", tableName, id) + session.engine.logger.Debugf("[cache] delete cache obj: %v, %v", tableName, id) sid, err := id.ToString() if err != nil { return err } cacher.DelBean(tableName, sid) } - session.engine.logger.Debug("[cacheDelete] clear cache table:", tableName) + session.engine.logger.Debugf("[cache] clear cache table: %v", tableName) cacher.ClearIds(tableName) return nil } diff --git a/session_find.go b/session_find.go index a3ba2c82..9551b767 100644 --- a/session_find.go +++ b/session_find.go @@ -141,7 +141,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) return err } err = nil // !nashtsai! reset err to nil for ErrCacheFailed - session.engine.logger.Warn("Cache Find Failed") + session.engine.logger.Warnf("Cache Find Failed") } } @@ -225,7 +225,7 @@ func (session *Session) noCacheFind(table *schemas.Table, containerValue reflect if elemType.Kind() == reflect.Struct { var newValue = newElemFunc(fields) dataStruct := utils.ReflectValue(newValue.Interface()) - tb, err := session.engine.tagParser.MapType(dataStruct) + tb, err := session.engine.tagParser.ParseWithCache(dataStruct) if err != nil { return err } @@ -307,7 +307,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for rows.Next() { i++ if i > 500 { - session.engine.logger.Debug("[cacheFind] ids length > 500, no cache") + session.engine.logger.Debugf("[cacheFind] ids length > 500, no cache") return ErrCacheFailed } var res = make([]string, len(table.PrimaryKeys)) @@ -326,13 +326,13 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ids = append(ids, pk) } - session.engine.logger.Debug("[cacheFind] cache sql:", ids, tableName, sqlStr, newsql, args) + session.engine.logger.Debugf("[cache] cache sql: %v, %v, %v, %v, %v", ids, tableName, sqlStr, newsql, args) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return err } } else { - session.engine.logger.Debug("[cacheFind] cache hit sql:", tableName, sqlStr, newsql, args) + session.engine.logger.Debugf("[cache] cache hit sql: %v, %v, %v, %v", tableName, sqlStr, newsql, args) } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) @@ -365,16 +365,20 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in ides = append(ides, id) ididxes[sid] = idx } else { - session.engine.logger.Debug("[cacheFind] cache hit bean:", tableName, id, bean) + session.engine.logger.Debugf("[cache] cache hit bean: %v, %v, %v", tableName, id, bean) + + pk, err := session.engine.IDOf(bean) + if err != nil { + return err + } - pk := session.engine.IDOf(bean) xid, err := pk.ToString() if err != nil { return err } if sid != xid { - session.engine.logger.Error("[cacheFind] error cache", xid, sid, bean) + session.engine.logger.Errorf("[cache] error cache: %v, %v, %v", xid, sid, bean) return ErrCacheFailed } temps[idx] = bean @@ -424,7 +428,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in bean := rv.Interface() temps[ididxes[sid]] = bean - session.engine.logger.Debug("[cacheFind] cache bean:", tableName, id, bean, temps) + session.engine.logger.Debugf("[cache] cache bean: %v, %v, %v, %v", tableName, id, bean, temps) cacher.PutBean(tableName, sid, bean) } } @@ -432,7 +436,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for j := 0; j < len(temps); j++ { bean := temps[j] if bean == nil { - session.engine.logger.Warn("[cacheFind] cache no hit:", tableName, ids[j], temps) + session.engine.logger.Warnf("[cache] cache no hit: %v, %v, %v", tableName, ids[j], temps) // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead continue } diff --git a/session_get.go b/session_get.go index f0fc016b..c468b440 100644 --- a/session_get.go +++ b/session_get.go @@ -79,7 +79,7 @@ func (session *Session) get(bean interface{}) (bool, error) { if context != nil { res := context.Get(fmt.Sprintf("%v-%v", sqlStr, args)) if res != nil { - session.engine.logger.Debug("hit context cache", sqlStr) + session.engine.logger.Debugf("hit context cache: %s", sqlStr) structValue := reflect.Indirect(reflect.ValueOf(bean)) structValue.Set(reflect.Indirect(reflect.ValueOf(res))) @@ -283,7 +283,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf tableName := session.statement.TableName() cacher := session.engine.cacherMgr.GetCacher(tableName) - session.engine.logger.Debug("[cache] Get SQL:", newsql, args) + session.engine.logger.Debugf("[cache] Get SQL: %s, %v", newsql, args) table := session.statement.RefTable ids, err := caches.GetCacheSql(cacher, tableName, newsql, args) if err != nil { @@ -319,19 +319,19 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } ids = []schemas.PK{pk} - session.engine.logger.Debug("[cache] cache ids:", newsql, ids) + session.engine.logger.Debugf("[cache] cache ids: %s, %v", newsql, ids) err = caches.PutCacheSql(cacher, ids, tableName, newsql, args) if err != nil { return false, err } } else { - session.engine.logger.Debug("[cache] cache hit:", newsql, ids) + session.engine.logger.Debugf("[cache] cache hit: %s, %v", newsql, ids) } if len(ids) > 0 { structValue := reflect.Indirect(reflect.ValueOf(bean)) id := ids[0] - session.engine.logger.Debug("[cache] get bean:", tableName, id) + session.engine.logger.Debugf("[cache] get bean: %s, %v", tableName, id) sid, err := id.ToString() if err != nil { return false, err @@ -344,10 +344,10 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf return has, err } - session.engine.logger.Debug("[cache] cache bean:", tableName, id, cacheBean) + session.engine.logger.Debugf("[cache] cache bean: %s, %v, %v", tableName, id, cacheBean) cacher.PutBean(tableName, sid, cacheBean) } else { - session.engine.logger.Debug("[cache] cache hit:", tableName, id, cacheBean) + session.engine.logger.Debugf("[cache] cache hit: %s, %v, %v", tableName, id, cacheBean) has = true } structValue.Set(reflect.Indirect(reflect.ValueOf(cacheBean))) diff --git a/session_insert.go b/session_insert.go index 2206ad05..4662e25a 100644 --- a/session_insert.go +++ b/session_insert.go @@ -485,7 +485,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -503,7 +503,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -526,7 +526,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -544,7 +544,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -567,7 +567,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { if table.Version != "" && session.statement.CheckVersion { verValue, err := table.VersionColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else if verValue.IsValid() && verValue.CanSet() { session.incrVersionFieldValue(verValue) } @@ -585,7 +585,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue, err := table.AutoIncrColumn().ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() { @@ -617,7 +617,7 @@ func (session *Session) cacheInsert(table string) error { if cacher == nil { return nil } - session.engine.logger.Debug("[cache] clear sql:", table) + session.engine.logger.Debugf("[cache] clear sql: %v", table) cacher.ClearIds(table) return nil } diff --git a/session_raw.go b/session_raw.go index efd74710..02dcbf56 100644 --- a/session_raw.go +++ b/session_raw.go @@ -7,7 +7,6 @@ package xorm import ( "database/sql" "reflect" - "time" "xorm.io/xorm/core" "xorm.io/xorm/internal/statements" @@ -27,27 +26,8 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row session.queryPreprocess(&sqlStr, args...) - if session.showSQL { - session.lastSQL = sqlStr - session.lastSQLArgs = args - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration) - } else { - session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) - } - }() - } else { - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } - } + session.lastSQL = sqlStr + session.lastSQLArgs = args if session.isAutoCommit { var db *core.DB @@ -156,25 +136,8 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er session.queryPreprocess(&sqlStr, args...) - if session.engine.showSQL { - if session.engine.showExecTime { - b4ExecTime := time.Now() - defer func() { - execDuration := time.Since(b4ExecTime) - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %s %#v - took: %v", sqlStr, args, execDuration) - } else { - session.engine.logger.Infof("[SQL] %s - took: %v", sqlStr, execDuration) - } - }() - } else { - if len(args) > 0 { - session.engine.logger.Infof("[SQL] %v %#v", sqlStr, args) - } else { - session.engine.logger.Infof("[SQL] %v", sqlStr) - } - } - } + session.lastSQL = sqlStr + session.lastSQLArgs = args if !session.isAutoCommit { return session.tx.ExecContext(session.ctx, sqlStr, args...) diff --git a/session_schema.go b/session_schema.go index 0279ced7..3617a6b8 100644 --- a/session_schema.go +++ b/session_schema.go @@ -242,7 +242,7 @@ func (session *Session) Sync2(beans ...interface{}) error { for _, bean := range beans { v := utils.ReflectValue(bean) - table, err := engine.tagParser.MapType(v) + table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } diff --git a/session_tx.go b/session_tx.go index ee3d473f..7a4861c6 100644 --- a/session_tx.go +++ b/session_tx.go @@ -4,6 +4,12 @@ package xorm +import ( + "time" + + "xorm.io/xorm/log" +) + // Begin a transaction func (session *Session) Begin() error { if session.isAutoCommit { @@ -14,6 +20,7 @@ func (session *Session) Begin() error { session.isAutoCommit = false session.isCommitedOrRollbacked = false session.tx = tx + session.saveLastSQL("BEGIN TRANSACTION") } return nil @@ -22,10 +29,23 @@ func (session *Session) Begin() error { // Rollback When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.isAutoCommit && !session.isCommitedOrRollbacked { - session.saveLastSQL(session.engine.dialect.RollBackStr()) + session.saveLastSQL("ROLL BACK") session.isCommitedOrRollbacked = true session.isAutoCommit = true - return session.tx.Rollback() + + start := time.Now() + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + }) + err := session.tx.Rollback() + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "ROLL BACK", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) + return err } return nil } @@ -36,48 +56,62 @@ func (session *Session) Commit() error { session.saveLastSQL("COMMIT") session.isCommitedOrRollbacked = true session.isAutoCommit = true - var err error - if err = session.tx.Commit(); err == nil { - // handle processors after tx committed - closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { - if closuresPtr != nil { - for _, closure := range *closuresPtr { - closure(bean) - } - } - } - for bean, closuresPtr := range session.afterInsertBeans { - closureCallFunc(closuresPtr, bean) + start := time.Now() + session.engine.logger.BeforeSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + }) + err := session.tx.Commit() + session.engine.logger.AfterSQL(log.LogContext{ + Ctx: session.ctx, + SQL: "COMMIT", + ExecuteTime: time.Now().Sub(start), + Err: err, + }) - if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { - processor.AfterInsert() - } - } - for bean, closuresPtr := range session.afterUpdateBeans { - closureCallFunc(closuresPtr, bean) - - if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - processor.AfterUpdate() - } - } - for bean, closuresPtr := range session.afterDeleteBeans { - closureCallFunc(closuresPtr, bean) - - if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { - processor.AfterDelete() - } - } - cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { - if len(*slices) > 0 { - *slices = make(map[interface{}]*[]func(interface{}), 0) - } - } - cleanUpFunc(&session.afterInsertBeans) - cleanUpFunc(&session.afterUpdateBeans) - cleanUpFunc(&session.afterDeleteBeans) + if err != nil { + return err } - return err + + // handle processors after tx committed + closureCallFunc := func(closuresPtr *[]func(interface{}), bean interface{}) { + if closuresPtr != nil { + for _, closure := range *closuresPtr { + closure(bean) + } + } + } + + for bean, closuresPtr := range session.afterInsertBeans { + closureCallFunc(closuresPtr, bean) + + if processor, ok := interface{}(bean).(AfterInsertProcessor); ok { + processor.AfterInsert() + } + } + for bean, closuresPtr := range session.afterUpdateBeans { + closureCallFunc(closuresPtr, bean) + + if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { + processor.AfterUpdate() + } + } + for bean, closuresPtr := range session.afterDeleteBeans { + closureCallFunc(closuresPtr, bean) + + if processor, ok := interface{}(bean).(AfterDeleteProcessor); ok { + processor.AfterDelete() + } + } + cleanUpFunc := func(slices *map[interface{}]*[]func(interface{})) { + if len(*slices) > 0 { + *slices = make(map[interface{}]*[]func(interface{}), 0) + } + } + cleanUpFunc(&session.afterInsertBeans) + cleanUpFunc(&session.afterUpdateBeans) + cleanUpFunc(&session.afterDeleteBeans) } return nil } diff --git a/session_update.go b/session_update.go index bb53c3a1..551b8167 100644 --- a/session_update.go +++ b/session_update.go @@ -30,7 +30,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri for _, filter := range session.engine.dialect.Filters() { newsql = filter.Do(newsql) } - session.engine.logger.Debug("[cacheUpdate] new sql", oldhead, newsql) + session.engine.logger.Debugf("[cache] new sql: %v, %v", oldhead, newsql) var nStart int if len(args) > 0 { @@ -43,7 +43,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri } cacher := session.engine.GetCacher(tableName) - session.engine.logger.Debug("[cacheUpdate] get cache sql", newsql, args[nStart:]) + session.engine.logger.Debugf("[cache] get cache sql: %v, %v", newsql, args[nStart:]) ids, err := caches.GetCacheSql(cacher, tableName, newsql, args[nStart:]) if err != nil { rows, err := session.NoCache().queryRows(newsql, args[nStart:]...) @@ -76,7 +76,7 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri ids = append(ids, pk) } - session.engine.logger.Debug("[cacheUpdate] find updated id", ids) + session.engine.logger.Debugf("[cache] find updated id: %v", ids) } /*else { session.engine.LogDebug("[xorm:cacheUpdate] del cached sql:", tableName, newsql, args) cacher.DelIds(tableName, genSqlKey(newsql, args)) @@ -109,9 +109,9 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri if col := table.GetColumn(colName); col != nil { fieldValue, err := col.ValueOf(bean) if err != nil { - session.engine.logger.Error(err) + session.engine.logger.Errorf("%v", err) } else { - session.engine.logger.Debug("[cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) + session.engine.logger.Debugf("[cache] set bean field: %v, %v, %v", bean, colName, fieldValue.Interface()) if col.IsVersion && session.statement.CheckVersion { session.incrVersionFieldValue(fieldValue) } else { @@ -119,16 +119,16 @@ func (session *Session) cacheUpdate(table *schemas.Table, tableName, sqlStr stri } } } else { - session.engine.logger.Errorf("[cacheUpdate] ERROR: column %v is not table %v's", + session.engine.logger.Errorf("[cache] ERROR: column %v is not table %v's", colName, table.Name) } } - session.engine.logger.Debug("[cacheUpdate] update cache", tableName, id, bean) + session.engine.logger.Debugf("[cache] update cache: %v, %v, %v", tableName, id, bean) cacher.PutBean(tableName, sid, bean) } } - session.engine.logger.Debug("[cacheUpdate] clear cached table sql:", tableName) + session.engine.logger.Debugf("[cache] clear cached table sql: %v", tableName) cacher.ClearIds(tableName) return nil } @@ -414,7 +414,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if cacher := session.engine.GetCacher(tableName); cacher != nil && session.statement.UseCache { // session.cacheUpdate(table, tableName, sqlStr, args...) - session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) + session.engine.logger.Debugf("[cache] clear table: %v", tableName) cacher.ClearIds(tableName) cacher.ClearBeans(tableName) } @@ -425,7 +425,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 closure(bean) } if processor, ok := interface{}(bean).(AfterUpdateProcessor); ok { - session.engine.logger.Debug("[event]", tableName, " has after update processor") + session.engine.logger.Debugf("[event] %v has after update processor", tableName) processor.AfterUpdate() } } else { diff --git a/tags/parser.go b/tags/parser.go index 5c94c55b..236d2d46 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -20,11 +20,15 @@ import ( "xorm.io/xorm/schemas" ) +var ( + ErrUnsupportedType = errors.New("Unsupported type") +) + type Parser struct { identifier string dialect dialects.Dialect - ColumnMapper names.Mapper - TableMapper names.Mapper + columnMapper names.Mapper + tableMapper names.Mapper handlers map[string]Handler cacherMgr *caches.Manager tableCache sync.Map // map[reflect.Type]*schemas.Table @@ -34,33 +38,39 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM return &Parser{ identifier: identifier, dialect: dialect, - TableMapper: tableMapper, - ColumnMapper: columnMapper, + tableMapper: tableMapper, + columnMapper: columnMapper, handlers: defaultTagHandlers, cacherMgr: cacherMgr, } } -func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) { - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col.Name) - col.Indexes[index.Name] = indexType - } else { - index := schemas.NewIndex(indexName, indexType) - index.AddColumn(col.Name) - table.AddIndex(index) - col.Indexes[index.Name] = indexType - } +func (parser *Parser) GetTableMapper() names.Mapper { + return parser.tableMapper } -func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { +func (parser *Parser) SetTableMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.tableMapper = mapper +} + +func (parser *Parser) GetColumnMapper() names.Mapper { + return parser.columnMapper +} + +func (parser *Parser) SetColumnMapper(mapper names.Mapper) { + parser.ClearCaches() + parser.columnMapper = mapper +} + +func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { t := v.Type() tableI, ok := parser.tableCache.Load(t) if ok { return tableI.(*schemas.Table), nil } - table, err := parser.mapType(v) + table, err := parser.Parse(v) if err != nil { return nil, err } @@ -78,16 +88,41 @@ func (parser *Parser) MapType(v reflect.Value) (*schemas.Table, error) { return table, nil } -// ClearTable removes the database mapper of a type from the cache -func (parser *Parser) ClearTable(t reflect.Type) { +// ClearCacheTable removes the database mapper of a type from the cache +func (parser *Parser) ClearCacheTable(t reflect.Type) { parser.tableCache.Delete(t) } -func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { +// ClearCaches removes all the cached table information parsed by structs +func (parser *Parser) ClearCaches() { + parser.tableCache = sync.Map{} +} + +func addIndex(indexName string, table *schemas.Table, col *schemas.Column, indexType int) { + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = indexType + } else { + index := schemas.NewIndex(indexName, indexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = indexType + } +} + +// Parse parses a struct as a table information +func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { t := v.Type() + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, ErrUnsupportedType + } + table := schemas.NewEmptyTable() table.Type = t - table.Name = names.GetTableName(parser.TableMapper, v) + table.Name = names.GetTableName(parser.tableMapper, v) var idFieldColName string var hasCacheTag, hasNoCacheTag bool @@ -204,7 +239,7 @@ func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { col.Length2 = col.SQLType.DefaultLength2 } if col.Name == "" { - col.Name = parser.ColumnMapper.Obj2Table(t.Field(i).Name) + col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name) } if ctx.isUnique { @@ -229,7 +264,7 @@ func (parser *Parser) mapType(v reflect.Value) (*schemas.Table, error) { } else { sqlType = schemas.Type2SQLType(fieldType) } - col = schemas.NewColumn(parser.ColumnMapper.Obj2Table(t.Field(i).Name), + col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true) diff --git a/tags/parser_test.go b/tags/parser_test.go new file mode 100644 index 00000000..6065bf2e --- /dev/null +++ b/tags/parser_test.go @@ -0,0 +1,44 @@ +// Copyright 2020 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tags + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "xorm.io/xorm/caches" + "xorm.io/xorm/dialects" + "xorm.io/xorm/names" +) + +type ParseTableName1 struct{} + +type ParseTableName2 struct{} + +func (p ParseTableName2) TableName() string { + return "p_parseTableName" +} + +func TestParseTableName(t *testing.T) { + parser := NewParser( + "xorm", + dialects.QueryDialect("mysql"), + names.SnakeMapper{}, + names.SnakeMapper{}, + caches.NewManager(), + ) + table, err := parser.Parse(reflect.ValueOf(new(ParseTableName1))) + assert.NoError(t, err) + assert.EqualValues(t, "parse_table_name1", table.Name) + + table, err = parser.Parse(reflect.ValueOf(new(ParseTableName2))) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) + + table, err = parser.Parse(reflect.ValueOf(ParseTableName2{})) + assert.NoError(t, err) + assert.EqualValues(t, "p_parseTableName", table.Name) +} diff --git a/tags/tag.go b/tags/tag.go index a043ed77..ee3f1e82 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -280,7 +280,7 @@ func ExtendsTagHandler(ctx *Context) error { isPtr = true fallthrough case reflect.Struct: - parentTable, err := ctx.parser.mapType(fieldValue) + parentTable, err := ctx.parser.Parse(fieldValue) if err != nil { return err } diff --git a/tags_test.go b/tags_test.go index 9d41a5fa..775fcf60 100644 --- a/tags_test.go +++ b/tags_test.go @@ -871,7 +871,7 @@ func TestAutoIncrTag(t *testing.T) { func TestTagComment(t *testing.T) { assert.NoError(t, prepareEngine()) // FIXME: only support mysql - if testEngine.Dialect().DriverName() != schemas.MYSQL { + if testEngine.Dialect().DBType() != schemas.MYSQL { return } diff --git a/xorm.go b/xorm.go index 2946b7c9..724a37cb 100644 --- a/xorm.go +++ b/xorm.go @@ -80,7 +80,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { logger := log.NewSimpleLogger(os.Stdout) logger.SetLevel(log.LOG_INFO) - engine.SetLogger(logger) + engine.SetLogger(log.NewLoggerAdapter(logger)) runtime.SetFinalizer(engine, close) diff --git a/xorm_test.go b/xorm_test.go index 2a24edb3..59f6c1a9 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -47,7 +47,7 @@ func createEngine(dbType, connStr string) error { var err error if !*cluster { - switch strings.ToLower(dbType) { + switch schemas.DBType(strings.ToLower(dbType)) { case schemas.MSSQL: db, err := sql.Open(dbType, strings.Replace(connStr, "xorm_test", "master", -1)) if err != nil {