From 7d2967c78698d206727517c4c3ea318c1564f2c7 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Tue, 29 Mar 2016 09:17:06 +0800 Subject: [PATCH] join parameters support & many comments --- VERSION | 2 +- engine.go | 58 ++++++++++++++++++++++++++++++++-------------------- logger.go | 16 +++++++++++++++ session.go | 31 ++++++++++++++-------------- statement.go | 42 +++++++++++++++++++++++++++---------- xorm.go | 12 +++++------ 6 files changed, 106 insertions(+), 55 deletions(-) diff --git a/VERSION b/VERSION index d3bc91a8..28d370f2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.5.2.0324 +xorm v0.5.2.0329 diff --git a/engine.go b/engine.go index abf8a7b2..b8d6f95c 100644 --- a/engine.go +++ b/engine.go @@ -54,6 +54,7 @@ type Engine struct { disableGlobalCache bool } +// ShowSQL show SQL statment or not on logger if log level is great than INFO func (engine *Engine) ShowSQL(show ...bool) { engine.logger.ShowSQL(show...) if len(show) == 0 { @@ -63,6 +64,7 @@ func (engine *Engine) ShowSQL(show ...bool) { } } +// ShowExecTime show SQL statment and execute time or not on logger if log level is great than INFO func (engine *Engine) ShowExecTime(show ...bool) { if len(show) == 0 { engine.showExecTime = true @@ -71,43 +73,51 @@ func (engine *Engine) ShowExecTime(show ...bool) { } } +// Logger return the logger interface func (engine *Engine) Logger() core.ILogger { return engine.logger } +// SetLogger set the new logger func (engine *Engine) SetLogger(logger core.ILogger) { engine.logger = logger engine.dialect.SetLogger(logger) } +// SetDisableGlobalCache disable global cache or not func (engine *Engine) SetDisableGlobalCache(disable bool) { if engine.disableGlobalCache != disable { engine.disableGlobalCache = disable } } +// DriverName return the current sql driver's name func (engine *Engine) DriverName() string { return engine.dialect.DriverName() } +// DataSourceName return the current connection string func (engine *Engine) DataSourceName() string { return engine.dialect.DataSourceName() } +// SetMapper set the name mapping rules func (engine *Engine) SetMapper(mapper core.IMapper) { engine.SetTableMapper(mapper) engine.SetColumnMapper(mapper) } +// SetTableMapper set the table name mapping rule func (engine *Engine) SetTableMapper(mapper core.IMapper) { engine.TableMapper = mapper } +// SetColumnMapper set the column name mapping rule func (engine *Engine) SetColumnMapper(mapper core.IMapper) { engine.ColumnMapper = mapper } -// If engine's database support batch insert records like +// SupportInsertMany If engine's database support batch insert records like // "insert into user values (name, age), (name, age)". // When the return is ture, then engine.Insert(&users) will // generate batch sql and exeute. @@ -115,13 +125,13 @@ func (engine *Engine) SupportInsertMany() bool { return engine.dialect.SupportInsertMany() } -// Engine's database use which charactor as quote. +// QuoteStr Engine's database use which charactor as quote. // mysql, sqlite use ` and postgres use " func (engine *Engine) QuoteStr() string { return engine.dialect.QuoteStr() } -// Use QuoteStr quote the string sql +// Quote Use QuoteStr quote the string sql func (engine *Engine) Quote(sql string) string { return engine.quoteTable(sql) } @@ -160,12 +170,12 @@ func (engine *Engine) quoteTable(keyName string) string { return engine.dialect.QuoteStr() + keyName + engine.dialect.QuoteStr() } -// A simple wrapper to dialect's core.SqlType method +// SqlType A simple wrapper to dialect's core.SqlType method func (engine *Engine) SqlType(c *core.Column) string { return engine.dialect.SqlType(c) } -// Database's autoincrement statement +// AutoIncrStr Database's autoincrement statement func (engine *Engine) AutoIncrStr() string { return engine.dialect.AutoIncrStr() } @@ -175,22 +185,17 @@ func (engine *Engine) SetMaxOpenConns(conns int) { engine.db.SetMaxOpenConns(conns) } -// @Deprecated -func (engine *Engine) SetMaxConns(conns int) { - engine.SetMaxOpenConns(conns) -} - -// SetMaxIdleConns +// SetMaxIdleConns set the max idle connections on pool, default is 2 func (engine *Engine) SetMaxIdleConns(conns int) { engine.db.SetMaxIdleConns(conns) } -// SetDefaltCacher set the default cacher. Xorm's default not enable cacher. +// SetDefaultCacher set the default cacher. Xorm's default not enable cacher. func (engine *Engine) SetDefaultCacher(cacher core.Cacher) { engine.Cacher = cacher } -// If you has set default cacher, and you want temporilly stop use cache, +// NoCache If you has set default cacher, and you want temporilly stop use cache, // you can use NoCache() func (engine *Engine) NoCache() *Session { session := engine.NewSession() @@ -198,13 +203,14 @@ func (engine *Engine) NoCache() *Session { return session.NoCache() } +// NoCascade If you do not want to auto cascade load object func (engine *Engine) NoCascade() *Session { session := engine.NewSession() session.IsAutoClose = true return session.NoCascade() } -// Set a table use a special cacher +// MapCacher Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) { v := rValue(bean) tb := engine.autoMapType(v) @@ -216,15 +222,17 @@ func (engine *Engine) NewDB() (*core.DB, error) { return core.OpenDialect(engine.dialect) } +// DB return the wrapper of sql.DB func (engine *Engine) DB() *core.DB { return engine.db } +// Dialect return database dialect func (engine *Engine) Dialect() core.Dialect { return engine.dialect } -// New a session +// NewSession New a session func (engine *Engine) NewSession() *Session { session := &Session{Engine: engine} session.Init() @@ -287,37 +295,42 @@ func (engine *Engine) logSQLExecutionTime(sqlStr string, args []interface{}, exe } } +// LogError logging error func (engine *Engine) LogError(contents ...interface{}) { engine.logger.Err(contents...) } +// LogErrorf logging errorf func (engine *Engine) LogErrorf(format string, contents ...interface{}) { engine.logger.Errf(format, contents...) } -// logging info +// LogInfo logging info func (engine *Engine) LogInfo(contents ...interface{}) { engine.logger.Info(contents...) } +// LogInfof logging infof func (engine *Engine) LogInfof(format string, contents ...interface{}) { engine.logger.Infof(format, contents...) } -// logging debug +// LogDebug logging debug func (engine *Engine) LogDebug(contents ...interface{}) { engine.logger.Debug(contents...) } +// LogDebugf logging debugf func (engine *Engine) LogDebugf(format string, contents ...interface{}) { engine.logger.Debugf(format, contents...) } -// logging warn +// LogWarn logging warn func (engine *Engine) LogWarn(contents ...interface{}) { engine.logger.Warning(contents...) } +// LogWarnf logging warnf func (engine *Engine) LogWarnf(format string, contents ...interface{}) { engine.logger.Warningf(format, contents...) } @@ -335,7 +348,7 @@ func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { return session.Sql(querystring, args...) } -// Default if your struct has "created" or "updated" filed tag, the fields +// NoAutoTime Default if your struct has "created" or "updated" filed tag, the fields // will automatically be filled with current time when Insert or Update // invoked. Call NoAutoTime if you dont' want to fill automatically. func (engine *Engine) NoAutoTime() *Session { @@ -344,13 +357,14 @@ func (engine *Engine) NoAutoTime() *Session { return session.NoAutoTime() } +// NoAutoCondition disable auto generate Where condition from bean or not func (engine *Engine) NoAutoCondition(no ...bool) *Session { session := engine.NewSession() session.IsAutoClose = true return session.NoAutoCondition(no...) } -// Retrieve all tables, columns, indexes' informations from database. +// DBMetas Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*core.Table, error) { tables, err := engine.dialect.GetTables() if err != nil { @@ -818,10 +832,10 @@ func (engine *Engine) OrderBy(order string) *Session { } // The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (engine *Engine) Join(join_operator string, tablename interface{}, condition string) *Session { +func (engine *Engine) Join(join_operator string, tablename interface{}, condition string, args ...interface{}) *Session { session := engine.NewSession() session.IsAutoClose = true - return session.Join(join_operator, tablename, condition) + return session.Join(join_operator, tablename, condition, args...) } // Generate Group By statement diff --git a/logger.go b/logger.go index 5302ca36..2a6f34ef 100644 --- a/logger.go +++ b/logger.go @@ -18,6 +18,7 @@ const ( DEFAULT_LOG_LEVEL = core.LOG_DEBUG ) +// SimpleLogger is the default implment of core.ILogger type SimpleLogger struct { DEBUG *log.Logger ERR *log.Logger @@ -29,14 +30,17 @@ type SimpleLogger struct { var _ core.ILogger = &SimpleLogger{} +// NewSimpleLogger use a special io.Writer as logger output func NewSimpleLogger(out io.Writer) *SimpleLogger { return NewSimpleLogger2(out, DEFAULT_LOG_PREFIX, DEFAULT_LOG_FLAG) } +// NewSimpleLogger2 let you customrize your logger prefix and flag func NewSimpleLogger2(out io.Writer, prefix string, flag int) *SimpleLogger { return NewSimpleLogger3(out, prefix, flag, DEFAULT_LOG_LEVEL) } +// NewSimpleLogger3 let you customrize your logger prefix and flag and logLevel func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) *SimpleLogger { return &SimpleLogger{ DEBUG: log.New(out, fmt.Sprintf("%s [debug] ", prefix), flag), @@ -47,6 +51,7 @@ func NewSimpleLogger3(out io.Writer, prefix string, flag int, l core.LogLevel) * } } +// Err implement core.ILogger func (s *SimpleLogger) Err(v ...interface{}) (err error) { if s.level <= core.LOG_ERR { s.ERR.Println(v...) @@ -54,6 +59,7 @@ func (s *SimpleLogger) Err(v ...interface{}) (err error) { return } +// Errf implement core.ILogger func (s *SimpleLogger) Errf(format string, v ...interface{}) (err error) { if s.level <= core.LOG_ERR { s.ERR.Printf(format, v...) @@ -61,6 +67,7 @@ func (s *SimpleLogger) Errf(format string, v ...interface{}) (err error) { return } +// Debug implement core.ILogger func (s *SimpleLogger) Debug(v ...interface{}) (err error) { if s.level <= core.LOG_DEBUG { s.DEBUG.Println(v...) @@ -68,6 +75,7 @@ func (s *SimpleLogger) Debug(v ...interface{}) (err error) { return } +// Debugf implement core.ILogger func (s *SimpleLogger) Debugf(format string, v ...interface{}) (err error) { if s.level <= core.LOG_DEBUG { s.DEBUG.Printf(format, v...) @@ -75,6 +83,7 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) (err error) { return } +// Info implement core.ILogger func (s *SimpleLogger) Info(v ...interface{}) (err error) { if s.level <= core.LOG_INFO { s.INFO.Println(v...) @@ -82,6 +91,7 @@ func (s *SimpleLogger) Info(v ...interface{}) (err error) { return } +// Infof implement core.ILogger func (s *SimpleLogger) Infof(format string, v ...interface{}) (err error) { if s.level <= core.LOG_INFO { s.INFO.Printf(format, v...) @@ -89,6 +99,7 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) (err error) { return } +// Warning implement core.ILogger func (s *SimpleLogger) Warning(v ...interface{}) (err error) { if s.level <= core.LOG_WARNING { s.WARN.Println(v...) @@ -96,6 +107,7 @@ func (s *SimpleLogger) Warning(v ...interface{}) (err error) { return } +// Warningf implement core.ILogger func (s *SimpleLogger) Warningf(format string, v ...interface{}) (err error) { if s.level <= core.LOG_WARNING { s.WARN.Printf(format, v...) @@ -103,15 +115,18 @@ func (s *SimpleLogger) Warningf(format string, v ...interface{}) (err error) { return } +// Level implement core.ILogger func (s *SimpleLogger) Level() core.LogLevel { return s.level } +// SetLevel implement core.ILogger func (s *SimpleLogger) SetLevel(l core.LogLevel) (err error) { s.level = l return } +// ShowSQL implement core.ILogger func (s *SimpleLogger) ShowSQL(show ...bool) { if len(show) == 0 { s.showSQL = true @@ -120,6 +135,7 @@ func (s *SimpleLogger) ShowSQL(show ...bool) { s.showSQL = show[0] } +// IsShowSQL implement core.ILogger func (s *SimpleLogger) IsShowSQL() bool { return s.showSQL } diff --git a/session.go b/session.go index f2650f9e..278c9263 100644 --- a/session.go +++ b/session.go @@ -309,31 +309,32 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { return session } -// Method NoCache ask this session do not retrieve data from cache system and +// NoCache ask this session do not retrieve data from cache system and // get data from database directly. func (session *Session) NoCache() *Session { session.Statement.UseCache = false return session } -//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (session *Session) Join(join_operator string, tablename interface{}, condition string) *Session { - session.Statement.Join(join_operator, tablename, condition) +// Join join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (session *Session) Join(joinOperator string, tablename interface{}, condition string, args ...interface{}) *Session { + session.Statement.Join(joinOperator, tablename, condition, args...) return session } -// Generate Group By statement +// GroupBy Generate Group By statement func (session *Session) GroupBy(keys string) *Session { session.Statement.GroupBy(keys) return session } -// Generate Having statement +// Having Generate Having statement func (session *Session) Having(conditions string) *Session { session.Statement.Having(conditions) return session } +// DB db return the wrapper of sql.DB func (session *Session) DB() *core.DB { if session.db == nil { session.db = session.Engine.db @@ -357,7 +358,7 @@ func (session *Session) Begin() error { return nil } -// When using transaction, you can rollback if any error +// Rollback When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { session.saveLastSQL(session.Engine.dialect.RollBackStr()) @@ -367,7 +368,7 @@ func (session *Session) Rollback() error { return nil } -// When using transaction, Commit will commit all operations. +// Commit When using transaction, Commit will commit all operations. func (session *Session) Commit() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { session.saveLastSQL("COMMIT") @@ -471,7 +472,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b return nil } -//Execute sql +// Execute sql func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { if session.prepareStmt { stmt, err := session.doPrepare(sqlStr) @@ -521,7 +522,7 @@ func (session *Session) Exec(sqlStr string, args ...interface{}) (sql.Result, er return session.exec(sqlStr, args...) } -// this function create a table according a bean +// CreateTable create a table according a bean func (session *Session) CreateTable(bean interface{}) error { v := rValue(bean) session.Statement.RefTable = session.Engine.mapType(v) @@ -534,7 +535,7 @@ func (session *Session) CreateTable(bean interface{}) error { return session.createOneTable() } -// create indexes +// CreateIndexes create indexes func (session *Session) CreateIndexes(bean interface{}) error { v := rValue(bean) session.Statement.RefTable = session.Engine.mapType(v) @@ -554,7 +555,7 @@ func (session *Session) CreateIndexes(bean interface{}) error { return nil } -// create uniques +// CreateUniques create uniques func (session *Session) CreateUniques(bean interface{}) error { v := rValue(bean) session.Statement.RefTable = session.Engine.mapType(v) @@ -1242,7 +1243,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - var columnStr string = session.Statement.ColumnStr + var columnStr = session.Statement.ColumnStr if len(session.Statement.selectStr) > 0 { columnStr = session.Statement.selectStr } else { @@ -1265,11 +1266,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - session.Statement.Params = append(session.Statement.Params, session.Statement.BeanArgs...) + session.Statement.Params = append(session.Statement.joinArgs, append(session.Statement.Params, session.Statement.BeanArgs...)...) session.Statement.attachInSql() - sqlStr = session.Statement.genSelectSql(columnStr) + sqlStr = session.Statement.genSelectSQL(columnStr) args = session.Statement.Params // for mssql and use limit qs := strings.Count(sqlStr, "?") diff --git a/statement.go b/statement.go index 867c46c1..97cf6b48 100644 --- a/statement.go +++ b/statement.go @@ -37,7 +37,7 @@ type exprParam struct { expr string } -// statement save all the sql info for executing SQL +// Statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table Engine *Engine @@ -48,6 +48,7 @@ type Statement struct { Params []interface{} OrderStr string JoinStr string + joinArgs []interface{} GroupByStr string HavingStr string ColumnStr string @@ -91,6 +92,7 @@ func (statement *Statement) Init() { statement.OrderStr = "" statement.UseCascade = true statement.JoinStr = "" + statement.joinArgs = make([]interface{}, 0) statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnStr = "" @@ -428,12 +430,27 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, return colNames, args } +func (statement *Statement) needTableName() bool { + return len(statement.JoinStr) > 0 +} + +func (statement *Statement) colName(col *core.Column, tableName string) string { + if statement.needTableName() { + var nm = tableName + if len(statement.TableAlias) > 0 { + nm = statement.TableAlias + } + return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) + } + return statement.Engine.Quote(col.Name) +} + // Auto generating conditions according a struct func buildConditions(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) ([]string, []interface{}) { - colNames := make([]string, 0) + var colNames []string var args = make([]interface{}, 0) for _, col := range table.Columns() { if !includeVersion && col.IsVersion { @@ -960,7 +977,7 @@ func (statement *Statement) Asc(colNames ...string) *Statement { } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN -func (statement *Statement) Join(join_operator string, tablename interface{}, condition string) *Statement { +func (statement *Statement) Join(join_operator string, tablename interface{}, condition string, args ...interface{}) *Statement { var buf bytes.Buffer if len(statement.JoinStr) > 0 { fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, join_operator) @@ -1003,6 +1020,7 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co fmt.Fprintf(&buf, " ON %v", condition) statement.JoinStr = buf.String() + statement.joinArgs = args return statement } @@ -1140,9 +1158,10 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) if len(statement.selectStr) > 0 { columnStr = statement.selectStr } else { + // TODO: always generate column names, not use * even if join if len(statement.JoinStr) == 0 { if len(columnStr) == 0 { - if statement.GroupByStr != "" { + if len(statement.GroupByStr) > 0 { columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) } else { columnStr = statement.genColumnStr() @@ -1150,7 +1169,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) } } else { if len(columnStr) == 0 { - if statement.GroupByStr != "" { + if len(statement.GroupByStr) > 0 { columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) } else { columnStr = "*" @@ -1160,7 +1179,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) } statement.attachInSql() // !admpub! fix bug:Iterate func missing "... IN (...)" - return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...) + return statement.genSelectSQL(columnStr), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) } func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { @@ -1203,15 +1222,15 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} } // count(index fieldname) > count(0) > count(*) - var id string = "*" + var id = "*" if statement.Engine.Dialect().DBType() == "ql" { id = "" } statement.attachInSql() - return statement.genSelectSql(fmt.Sprintf("count(%v)", id)), append(statement.Params, statement.BeanArgs...) + return statement.genSelectSQL(fmt.Sprintf("count(%v)", id)), append(append(statement.joinArgs, statement.Params...), statement.BeanArgs...) } -func (statement *Statement) genSelectSql(columnStr string) (a string) { +func (statement *Statement) genSelectSQL(columnStr string) (a string) { var distinct string if statement.IsDistinct { distinct = "DISTINCT " @@ -1322,11 +1341,12 @@ func (statement *Statement) processIdParam() { if statement.IdParam != nil { if statement.Engine.dialect.DBType() != "ql" { for i, col := range statement.RefTable.PKColumns() { + var colName = statement.colName(col, statement.TableName()) if i < len(*(statement.IdParam)) { - statement.And(fmt.Sprintf("%v %s ?", statement.Engine.Quote(col.Name), + statement.And(fmt.Sprintf("%v %s ?", colName, statement.Engine.dialect.EqStr()), (*(statement.IdParam))[i]) } else { - statement.And(fmt.Sprintf("%v %s ?", statement.Engine.Quote(col.Name), + statement.And(fmt.Sprintf("%v %s ?", colName, statement.Engine.dialect.EqStr()), "") } } diff --git a/xorm.go b/xorm.go index 82ee2e70..91aa8bc7 100644 --- a/xorm.go +++ b/xorm.go @@ -5,7 +5,6 @@ package xorm import ( - "errors" "fmt" "os" "reflect" @@ -17,7 +16,8 @@ import ( ) const ( - Version string = "0.5.2.0324" + // Version show the xorm's version + Version string = "0.5.2.0329" ) func regDrvsNDialects() bool { @@ -49,13 +49,13 @@ func close(engine *Engine) { engine.Close() } -// new a db manager according to the parameter. Currently support four +// NewEngine new a db manager according to the parameter. Currently support four // drivers func NewEngine(driverName string, dataSourceName string) (*Engine, error) { regDrvsNDialects() driver := core.QueryDriver(driverName) if driver == nil { - return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) + return nil, fmt.Errorf("Unsupported driver name: %v", driverName) } uri, err := driver.Parse(driverName, dataSourceName) @@ -65,7 +65,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { dialect := core.QueryDialect(uri.DbType) if dialect == nil { - return nil, errors.New(fmt.Sprintf("Unsupported dialect type: %v", uri.DbType)) + return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DbType) } db, err := core.Open(driverName, dataSourceName) @@ -97,7 +97,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return engine, nil } -// clone an engine +// Clone clone an engine func (engine *Engine) Clone() (*Engine, error) { return NewEngine(engine.DriverName(), engine.DataSourceName()) }