diff --git a/VERSION b/VERSION index b91f308b..bcdbd306 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.6.0.0917Beta +xorm v0.6.0.0919Beta diff --git a/rows.go b/rows.go index 2ef8d986..e441e274 100644 --- a/rows.go +++ b/rows.go @@ -41,7 +41,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { } if rows.session.Statement.RawSQL == "" { - sqlStr, args = rows.session.Statement.genGetSql(bean) + sqlStr, args = rows.session.Statement.genGetSQL(bean) } else { sqlStr = rows.session.Statement.RawSQL args = rows.session.Statement.RawParams diff --git a/session.go b/session.go index a58b6417..e0391d23 100644 --- a/session.go +++ b/session.go @@ -695,7 +695,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf for _, filter := range session.Engine.dialect.Filters() { sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) } - newsql := session.Statement.convertIdSql(sqlStr) + newsql := session.Statement.convertIDSQL(sqlStr) if newsql == "" { return false, ErrCacheFailed } @@ -796,7 +796,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) } - newsql := session.Statement.convertIdSql(sqlStr) + newsql := session.Statement.convertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } @@ -1042,7 +1042,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { return false, ErrTableNotFound } session.Statement.Limit(1) - sqlStr, args = session.Statement.genGetSql(bean) + sqlStr, args = session.Statement.genGetSQL(bean) } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -1093,7 +1093,7 @@ func (session *Session) Count(bean interface{}) (int64, error) { var sqlStr string var args []interface{} if session.Statement.RawSQL == "" { - sqlStr, args = session.Statement.genCountSql(bean) + sqlStr, args = session.Statement.genCountSQL(bean) } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -1125,7 +1125,7 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error var sqlStr string var args []interface{} if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSql(bean, columnName) + sqlStr, args = session.Statement.genSumSQL(bean, columnName) } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -1157,7 +1157,7 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64 var sqlStr string var args []interface{} if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSql(bean, columnNames...) + sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -1189,7 +1189,7 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6 var sqlStr string var args []interface{} if len(session.Statement.RawSQL) == 0 { - sqlStr, args = session.Statement.genSumSql(bean, columnNames...) + sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) } else { sqlStr = session.Statement.RawSQL args = session.Statement.RawParams @@ -3693,7 +3693,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) } - newsql := session.Statement.convertIdSql(sqlStr) + newsql := session.Statement.convertIDSQL(sqlStr) if newsql == "" { return ErrCacheFailed } diff --git a/statement.go b/statement.go index 5536d9cf..745024ef 100644 --- a/statement.go +++ b/statement.go @@ -1064,9 +1064,19 @@ func (s *Statement) genDelIndexSQL() []string { return sqls } -func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { - statement.setRefValue(rValue(bean)) +func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { + quote := s.Engine.Quote + sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(s.TableName()), + col.String(s.Engine.dialect)) + return sql, []interface{}{} +} +func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { + return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) +} + +func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { var table = statement.RefTable var addedTableName = (len(statement.JoinStr) > 0) @@ -1075,12 +1085,18 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) var err error autoCond, err = statement.buildConds(table, bean, true, true, false, true, addedTableName) if err != nil { - panic(err) + return "", nil, err } } statement.processIdParam() + return builder.ToSQL(statement.cond.And(autoCond)) +} + +func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { + statement.setRefValue(rValue(bean)) + var columnStr = statement.ColumnStr if len(statement.selectStr) > 0 { columnStr = statement.selectStr @@ -1105,61 +1121,29 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) } } - inSQL, inArgs, _ := builder.ToSQL(statement.cond.And(autoCond)) + condSQL, condArgs, _ := statement.genConds(bean) - return statement.genSelectSQL(columnStr, inSQL), append(statement.joinArgs, inArgs...) + return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) } -func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { - quote := s.Engine.Quote - sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(s.TableName()), - col.String(s.Engine.dialect)) - return sql, []interface{}{} -} - -func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { - return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) -} - -func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { +func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) { statement.setRefValue(rValue(bean)) - var autoCond builder.Cond - if !statement.noAutoCondition { - var err error - var addedTableName = (len(statement.JoinStr) > 0) - autoCond, err = statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) - if err != nil { - panic(err) - } - } - - // count(index fieldname) > count(0) > count(*) - condSQL, condArgs, _ := builder.ToSQL(statement.cond.And(autoCond)) + condSQL, condArgs, _ := statement.genConds(bean) return statement.genSelectSQL("count(*)", condSQL), append(statement.joinArgs, condArgs...) } -func (statement *Statement) genSumSql(bean interface{}, columns ...string) (string, []interface{}) { +func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) { statement.setRefValue(rValue(bean)) - var addedTableName = (len(statement.JoinStr) > 0) - var autoCond builder.Cond - if !statement.noAutoCondition { - var err error - autoCond, err = statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) - if err != nil { - panic(err) - } - } - - condSQL, condArgs, _ := builder.ToSQL(statement.cond.And(autoCond)) - var sumStrs = make([]string, 0, len(columns)) for _, colName := range columns { sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName)) } + + condSQL, condArgs, _ := statement.genConds(bean) + return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...) } @@ -1262,19 +1246,21 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { } func (statement *Statement) processIdParam() { - if statement.IdParam != nil { - for i, col := range statement.RefTable.PKColumns() { - var colName = statement.colName(col, statement.TableName()) - if i < len(*(statement.IdParam)) { - statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]}) - } else { - statement.cond = statement.cond.And(builder.Eq{colName: ""}) - } + if statement.IdParam == nil { + return + } + + for i, col := range statement.RefTable.PKColumns() { + var colName = statement.colName(col, statement.TableName()) + if i < len(*(statement.IdParam)) { + statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]}) + } else { + statement.cond = statement.cond.And(builder.Eq{colName: ""}) } } } -func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string { +func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { var colnames = make([]string, len(cols)) for i, col := range cols { if includeTableName { @@ -1287,21 +1273,19 @@ func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bo return strings.Join(colnames, ", ") } -func (statement *Statement) convertIdSql(sqlStr string) string { +func (statement *Statement) convertIDSQL(sqlStr string) string { if statement.RefTable != nil { cols := statement.RefTable.PKColumns() if len(cols) == 0 { return "" } - colstrs := statement.JoinColumns(cols, false) + colstrs := statement.joinColumns(cols, false) sqls := splitNNoCase(sqlStr, " from ", 2) if len(sqls) != 2 { return "" } - if statement.Engine.dialect.DBType() == "ql" { - return fmt.Sprintf("SELECT id() FROM %v", sqls[1]) - } + return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1]) } return "" @@ -1312,7 +1296,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { return "", "" } - colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true) + colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true) sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { diff --git a/xorm.go b/xorm.go index d0edf7cd..237d34c2 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( const ( // Version show the xorm's version - Version string = "0.6.0.0917Beta" + Version string = "0.6.0.0919Beta" ) func regDrvsNDialects() bool {