fix bug and catch more tests (#613)

This commit is contained in:
Lunny Xiao 2017-06-08 19:38:52 +08:00 committed by GitHub
parent 7a9bf19c65
commit 1e055bac01
10 changed files with 144 additions and 43 deletions

View File

@ -33,8 +33,9 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if err := rows.session.Statement.setRefValue(rValue(bean)); err != nil { if err = rows.session.Statement.setRefValue(rValue(bean)); err != nil {
return nil, err return nil, err
} }
@ -43,7 +44,10 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
} }
if rows.session.Statement.RawSQL == "" { if rows.session.Statement.RawSQL == "" {
sqlStr, args = rows.session.Statement.genGetSQL(bean) sqlStr, args, err = rows.session.Statement.genGetSQL(bean)
if err != nil {
return nil, err
}
} else { } else {
sqlStr = rows.session.Statement.RawSQL sqlStr = rows.session.Statement.RawSQL
args = rows.session.Statement.RawParams args = rows.session.Statement.RawParams
@ -54,7 +58,6 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
} }
rows.session.saveLastSQL(sqlStr, args...) rows.session.saveLastSQL(sqlStr, args...)
var err error
if rows.session.prepareStmt { if rows.session.prepareStmt {
rows.stmt, err = rows.session.DB().Prepare(sqlStr) rows.stmt, err = rows.session.DB().Prepare(sqlStr)
if err != nil { if err != nil {

View File

@ -48,6 +48,8 @@ type Session struct {
//beforeSQLExec func(string, ...interface{}) //beforeSQLExec func(string, ...interface{})
lastSQL string lastSQL string
lastSQLArgs []interface{} lastSQLArgs []interface{}
err error
} }
// Clone copy all the session's content and return a new session // Clone copy all the session's content and return a new session
@ -620,7 +622,7 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i
structInter := reflect.New(fieldValue.Type()) structInter := reflect.New(fieldValue.Type())
newsession := session.Engine.NewSession() newsession := session.Engine.NewSession()
defer newsession.Close() defer newsession.Close()
has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) has, err := newsession.ID(pk).NoCascade().Get(structInter.Interface())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -98,8 +98,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
processor.BeforeDelete() processor.BeforeDelete()
} }
// -- condSQL, condArgs, err := session.Statement.genConds(bean)
condSQL, condArgs, _ := session.Statement.genConds(bean) if err != nil {
return 0, err
}
if len(condSQL) == 0 && session.Statement.LimitN == 0 { if len(condSQL) == 0 && session.Statement.LimitN == 0 {
return 0, ErrNeedDeletedCond return 0, ErrNeedDeletedCond
} }

View File

@ -15,7 +15,7 @@ func TestDelete(t *testing.T) {
assert.NoError(t, prepareEngine()) assert.NoError(t, prepareEngine())
type UserinfoDelete struct { type UserinfoDelete struct {
Uid int64 Uid int64 `xorm:"id pk not null autoincr"`
IsMan bool IsMan bool
} }

View File

@ -91,6 +91,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 { if len(session.Statement.TableName()) <= 0 {
return ErrTableNotFound return ErrTableNotFound
@ -128,7 +129,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
args = append(session.Statement.joinArgs, condArgs...) args = append(session.Statement.joinArgs, condArgs...)
sqlStr = session.Statement.genSelectSQL(columnStr, condSQL) sqlStr, err = session.Statement.genSelectSQL(columnStr, condSQL)
if err != nil {
return err
}
// for mssql and use limit // for mssql and use limit
qs := strings.Count(sqlStr, "?") qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs { if len(args)*2 == qs {
@ -139,7 +143,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
args = session.Statement.RawParams args = session.Statement.RawParams
} }
var err error
if session.canCache() { if session.canCache() {
if cacher := session.Engine.getCacher2(table); cacher != nil && if cacher := session.Engine.getCacher2(table); cacher != nil &&
!session.Statement.IsDistinct && !session.Statement.IsDistinct &&

View File

@ -33,13 +33,17 @@ func (session *Session) Get(bean interface{}) (bool, error) {
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
if len(session.Statement.TableName()) <= 0 { if len(session.Statement.TableName()) <= 0 {
return false, ErrTableNotFound return false, ErrTableNotFound
} }
session.Statement.Limit(1) session.Statement.Limit(1)
sqlStr, args = session.Statement.genGetSQL(bean) sqlStr, args, err = session.Statement.genGetSQL(bean)
if err != nil {
return false, err
}
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams

View File

@ -1139,3 +1139,29 @@ func TestCompositePK(t *testing.T) {
assert.EqualValues(t, "uid", pkCols[0].Name) assert.EqualValues(t, "uid", pkCols[0].Name)
assert.EqualValues(t, "tid", pkCols[1].Name) assert.EqualValues(t, "tid", pkCols[1].Name)
} }
func TestNoPKIdQueryUpdate(t *testing.T) {
type NoPKTable struct {
Username string
}
assert.NoError(t, prepareEngine())
assertSync(t, new(NoPKTable))
cnt, err := testEngine.Insert(&NoPKTable{
Username: "test",
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
var res NoPKTable
has, err := testEngine.ID("test").Get(&res)
assert.Error(t, err)
assert.False(t, has)
cnt, err = testEngine.ID("test").Update(&NoPKTable{
Username: "test1",
})
assert.Error(t, err)
assert.EqualValues(t, 0, cnt)
}

View File

@ -16,11 +16,15 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if session.Statement.RawSQL == "" { if session.Statement.RawSQL == "" {
if len(bean) == 0 { if len(bean) == 0 {
return 0, ErrTableNotFound return 0, ErrTableNotFound
} }
sqlStr, args = session.Statement.genCountSQL(bean[0]) sqlStr, args, err = session.Statement.genCountSQL(bean[0])
if err != nil {
return 0, err
}
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
@ -28,7 +32,6 @@ func (session *Session) Count(bean ...interface{}) (int64, error) {
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
var err error
var total int64 var total int64
if session.IsAutoCommit { if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&total) err = session.DB().QueryRow(sqlStr, args...).Scan(&total)
@ -52,8 +55,12 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 { if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSQL(bean, columnName) sqlStr, args, err = session.Statement.genSumSQL(bean, columnName)
if err != nil {
return 0, err
}
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
@ -61,7 +68,6 @@ func (session *Session) Sum(bean interface{}, columnName string) (float64, error
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
var err error
var res float64 var res float64
if session.IsAutoCommit { if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).Scan(&res) err = session.DB().QueryRow(sqlStr, args...).Scan(&res)
@ -84,8 +90,12 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 { if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
if err != nil {
return nil, err
}
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
@ -93,7 +103,6 @@ func (session *Session) Sums(bean interface{}, columnNames ...string) ([]float64
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
var err error
var res = make([]float64, len(columnNames), len(columnNames)) var res = make([]float64, len(columnNames), len(columnNames))
if session.IsAutoCommit { if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)
@ -116,8 +125,12 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
var sqlStr string var sqlStr string
var args []interface{} var args []interface{}
var err error
if len(session.Statement.RawSQL) == 0 { if len(session.Statement.RawSQL) == 0 {
sqlStr, args = session.Statement.genSumSQL(bean, columnNames...) sqlStr, args, err = session.Statement.genSumSQL(bean, columnNames...)
if err != nil {
return nil, err
}
} else { } else {
sqlStr = session.Statement.RawSQL sqlStr = session.Statement.RawSQL
args = session.Statement.RawParams args = session.Statement.RawParams
@ -125,7 +138,6 @@ func (session *Session) SumsInt(bean interface{}, columnNames ...string) ([]int6
session.queryPreprocess(&sqlStr, args...) session.queryPreprocess(&sqlStr, args...)
var err error
var res = make([]int64, len(columnNames), len(columnNames)) var res = make([]int64, len(columnNames), len(columnNames))
if session.IsAutoCommit { if session.IsAutoCommit {
err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res) err = session.DB().QueryRow(sqlStr, args...).ScanSlice(&res)

View File

@ -236,7 +236,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr) colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+v.expr)
} }
session.Statement.processIDParam() if err = session.Statement.processIDParam(); err != nil {
return 0, err
}
var autoCond builder.Cond var autoCond builder.Cond
if !session.Statement.noAutoCondition && len(condiBean) > 0 { if !session.Statement.noAutoCondition && len(condiBean) > 0 {
@ -267,7 +269,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1") colNames = append(colNames, session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1")
} }
condSQL, condArgs, _ = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
@ -285,7 +291,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, _ = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
@ -293,7 +302,11 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN) tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)", cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...)) session.Engine.Quote(session.Statement.TableName()), tempCondSQL), condArgs...))
condSQL, condArgs, _ = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }
@ -304,7 +317,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0], table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...) session.Engine.Quote(session.Statement.TableName()), condSQL), condArgs...)
condSQL, condArgs, _ = builder.ToSQL(cond) condSQL, condArgs, err = builder.ToSQL(cond)
if err != nil {
return 0, err
}
if len(condSQL) > 0 { if len(condSQL) > 0 {
condSQL = "WHERE " + condSQL condSQL = "WHERE " + condSQL
} }

View File

@ -1118,12 +1118,14 @@ func (statement *Statement) genConds(bean interface{}) (string, []interface{}, e
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
} }
statement.processIDParam() if err := statement.processIDParam(); err != nil {
return "", nil, err
}
return builder.ToSQL(statement.cond) return builder.ToSQL(statement.cond)
} }
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) { func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean) v := rValue(bean)
isStruct := v.Kind() == reflect.Struct isStruct := v.Kind() == reflect.Struct
if isStruct { if isStruct {
@ -1158,19 +1160,31 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{})
var condSQL string var condSQL string
var condArgs []interface{} var condArgs []interface{}
var err error
if isStruct { if isStruct {
condSQL, condArgs, _ = statement.genConds(bean) condSQL, condArgs, err = statement.genConds(bean)
} else { } else {
condSQL, condArgs, _ = builder.ToSQL(statement.cond) condSQL, condArgs, err = builder.ToSQL(statement.cond)
}
if err != nil {
return "", nil, err
} }
return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...) sqlStr, err := statement.genSelectSQL(columnStr, condSQL)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) { func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}, error) {
statement.setRefValue(rValue(bean)) statement.setRefValue(rValue(bean))
condSQL, condArgs, _ := statement.genConds(bean) condSQL, condArgs, err := statement.genConds(bean)
if err != nil {
return "", nil, err
}
var selectSQL = statement.selectStr var selectSQL = statement.selectStr
if len(selectSQL) <= 0 { if len(selectSQL) <= 0 {
@ -1180,10 +1194,15 @@ func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}
selectSQL = "count(*)" selectSQL = "count(*)"
} }
} }
return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...) sqlStr, err := statement.genSelectSQL(selectSQL, condSQL)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) { func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
statement.setRefValue(rValue(bean)) statement.setRefValue(rValue(bean))
var sumStrs = make([]string, 0, len(columns)) var sumStrs = make([]string, 0, len(columns))
@ -1195,12 +1214,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri
} }
sumSelect := strings.Join(sumStrs, ", ") sumSelect := strings.Join(sumStrs, ", ")
condSQL, condArgs, _ := statement.genConds(bean) condSQL, condArgs, err := statement.genConds(bean)
if err != nil {
return "", nil, err
}
return statement.genSelectSQL(sumSelect, condSQL), append(statement.joinArgs, condArgs...) sqlStr, err := statement.genSelectSQL(sumSelect, condSQL)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) { func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) {
var distinct string var distinct string
if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
distinct = "DISTINCT " distinct = "DISTINCT "
@ -1211,7 +1238,9 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
var top string var top string
var mssqlCondi string var mssqlCondi string
statement.processIDParam() if err := statement.processIDParam(); err != nil {
return "", err
}
var buf bytes.Buffer var buf bytes.Buffer
if len(condSQL) > 0 { if len(condSQL) > 0 {
@ -1314,19 +1343,23 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
return return
} }
func (statement *Statement) processIDParam() { func (statement *Statement) processIDParam() error {
if statement.idParam == nil { if statement.idParam == nil {
return return nil
}
if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
len(statement.RefTable.PrimaryKeys),
len(*statement.idParam),
)
} }
for i, col := range statement.RefTable.PKColumns() { for i, col := range statement.RefTable.PKColumns() {
var colName = statement.colName(col, statement.TableName()) var colName = statement.colName(col, statement.TableName())
if i < len(*(statement.idParam)) { statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
} else {
statement.cond = statement.cond.And(builder.Eq{colName: ""})
}
} }
return nil
} }
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string { func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {