This commit is contained in:
Lunny Xiao 2016-09-19 11:13:40 +08:00
parent 1b773e8762
commit c9b09da6e1
5 changed files with 53 additions and 69 deletions

View File

@ -1 +1 @@
xorm v0.6.0.0917Beta
xorm v0.6.0.0919Beta

View File

@ -41,7 +41,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
}
if rows.session.Statement.RawSQL == "" {
sqlStr, args = rows.session.Statement.genGetSql(bean)
sqlStr, args = rows.session.Statement.genGetSQL(bean)
} else {
sqlStr = rows.session.Statement.RawSQL
args = rows.session.Statement.RawParams

View File

@ -695,7 +695,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
for _, filter := range session.Engine.dialect.Filters() {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
newsql := session.Statement.convertIdSql(sqlStr)
newsql := session.Statement.convertIDSQL(sqlStr)
if newsql == "" {
return false, ErrCacheFailed
}
@ -796,7 +796,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
newsql := session.Statement.convertIdSql(sqlStr)
newsql := session.Statement.convertIDSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}
@ -1042,7 +1042,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
return false, ErrTableNotFound
}
session.Statement.Limit(1)
sqlStr, args = session.Statement.genGetSql(bean)
sqlStr, args = session.Statement.genGetSQL(bean)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
@ -1093,7 +1093,7 @@ func (session *Session) Count(bean interface{}) (int64, error) {
var sqlStr string
var args []interface{}
if session.Statement.RawSQL == "" {
sqlStr, args = session.Statement.genCountSql(bean)
sqlStr, args = session.Statement.genCountSQL(bean)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
@ -1125,7 +1125,7 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
var sqlStr string
var args []interface{}
if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSql(bean, columnName)
sqlStr, args = session.Statement.genSumSQL(bean, columnName)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
@ -1157,7 +1157,7 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
var sqlStr string
var args []interface{}
if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSql(bean, columnNames...)
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
@ -1189,7 +1189,7 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
var sqlStr string
var args []interface{}
if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSql(bean, columnNames...)
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...)
} else {
sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams
@ -3693,7 +3693,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
}
newsql := session.Statement.convertIdSql(sqlStr)
newsql := session.Statement.convertIDSQL(sqlStr)
if newsql == "" {
return ErrCacheFailed
}

View File

@ -1064,9 +1064,19 @@ func (s *Statement) genDelIndexSQL() []string {
return sqls
}
func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) {
statement.setRefValue(rValue(bean))
func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := s.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(s.TableName()),
col.String(s.Engine.dialect))
return sql, []interface{}{}
}
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
var table = statement.RefTable
var addedTableName = (len(statement.JoinStr) > 0)
@ -1075,12 +1085,18 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
var err error
autoCond, err = statement.buildConds(table, bean, true, true, false, true, addedTableName)
if err != nil {
panic(err)
return "", nil, err
}
}
statement.processIdParam()
return builder.ToSQL(statement.cond.And(autoCond))
}
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
statement.setRefValue(rValue(bean))
var columnStr = statement.ColumnStr
if len(statement.selectStr) > 0 {
columnStr = statement.selectStr
@ -1105,61 +1121,29 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
}
}
inSQL, inArgs, _ := builder.ToSQL(statement.cond.And(autoCond))
condSQL, condArgs, _ := statement.genConds(bean)
return statement.genSelectSQL(columnStr, inSQL), append(statement.joinArgs, inArgs...)
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
}
func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := s.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(s.TableName()),
col.String(s.Engine.dialect))
return sql, []interface{}{}
}
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}
func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) {
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
statement.setRefValue(rValue(bean))
var autoCond builder.Cond
if !statement.noAutoCondition {
var err error
var addedTableName = (len(statement.JoinStr) > 0)
autoCond, err = statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil {
panic(err)
}
}
// count(index fieldname) > count(0) > count(*)
condSQL, condArgs, _ := builder.ToSQL(statement.cond.And(autoCond))
condSQL, condArgs, _ := statement.genConds(bean)
return statement.genSelectSQL("count(*)", condSQL), append(statement.joinArgs, condArgs...)
}
func (statement *Statement) genSumSql(bean interface{}, columns ...string) (string, []interface{}) {
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
statement.setRefValue(rValue(bean))
var addedTableName = (len(statement.JoinStr) > 0)
var autoCond builder.Cond
if !statement.noAutoCondition {
var err error
autoCond, err = statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil {
panic(err)
}
}
condSQL, condArgs, _ := builder.ToSQL(statement.cond.And(autoCond))
var sumStrs = make([]string, 0, len(columns))
for _, colName := range columns {
sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
}
condSQL, condArgs, _ := statement.genConds(bean)
return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
}
@ -1262,19 +1246,21 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
}
func (statement *Statement) processIdParam() {
if statement.IdParam != nil {
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
if i < len(*(statement.IdParam)) {
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]})
} else {
statement.cond = statement.cond.And(builder.Eq{colName: ""})
}
if statement.IdParam == nil {
return
}
for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName())
if i < len(*(statement.IdParam)) {
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]})
} else {
statement.cond = statement.cond.And(builder.Eq{colName: ""})
}
}
}
func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string {
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
var colnames = make([]string, len(cols))
for i, col := range cols {
if includeTableName {
@ -1287,21 +1273,19 @@ func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bo
return strings.Join(colnames, ", ")
}
func (statement *Statement) convertIdSql(sqlStr string) string {
func (statement *Statement) convertIDSQL(sqlStr string) string {
if statement.RefTable != nil {
cols := statement.RefTable.PKColumns()
if len(cols) == 0 {
return ""
}
colstrs := statement.JoinColumns(cols, false)
colstrs := statement.joinColumns(cols, false)
sqls := splitNNoCase(sqlStr, " from ", 2)
if len(sqls) != 2 {
return ""
}
if statement.Engine.dialect.DBType() == "ql" {
return fmt.Sprintf("SELECT id() FROM %v", sqls[1])
}
return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
}
return ""
@ -1312,7 +1296,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
return "", ""
}
colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true)
colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
sqls := splitNNoCase(sqlStr, "where", 2)
if len(sqls) != 2 {
if len(sqls) == 1 {

View File

@ -17,7 +17,7 @@ import (
const (
// Version show the xorm's version
Version string = "0.6.0.0917Beta"
Version string = "0.6.0.0919Beta"
)
func regDrvsNDialects() bool {