Merge branch 'master' into fix_batch_insert_interface_slice
This commit is contained in:
commit
e15f07c284
|
@ -901,7 +901,7 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
|
func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string {
|
||||||
if len(db.Schema) == 0 {
|
if len(db.Schema) == 0 || strings.Contains(tableName, ".") {
|
||||||
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
|
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s",
|
||||||
tableName, col.Name, db.SqlType(col))
|
tableName, col.Name, db.SqlType(col))
|
||||||
}
|
}
|
||||||
|
@ -913,8 +913,8 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string {
|
||||||
quote := db.Quote
|
quote := db.Quote
|
||||||
idxName := index.Name
|
idxName := index.Name
|
||||||
|
|
||||||
tableName = strings.Replace(tableName, `"`, "", -1)
|
tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".")
|
||||||
tableName = strings.Replace(tableName, `.`, "_", -1)
|
tableName = tableParts[len(tableParts)-1]
|
||||||
|
|
||||||
if !strings.HasPrefix(idxName, "UQE_") &&
|
if !strings.HasPrefix(idxName, "UQE_") &&
|
||||||
!strings.HasPrefix(idxName, "IDX_") {
|
!strings.HasPrefix(idxName, "IDX_") {
|
||||||
|
|
11
helpers.go
11
helpers.go
|
@ -155,6 +155,17 @@ func isZero(k interface{}) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isZeroValue(v reflect.Value) bool {
|
||||||
|
if isZero(v.Interface()) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||||
|
return v.IsNil()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func isStructZero(v reflect.Value) bool {
|
func isStructZero(v reflect.Value) bool {
|
||||||
if !v.IsValid() {
|
if !v.IsValid() {
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -92,6 +92,7 @@ type EngineInterface interface {
|
||||||
Quote(string) string
|
Quote(string) string
|
||||||
SetCacher(string, core.Cacher)
|
SetCacher(string, core.Cacher)
|
||||||
SetConnMaxLifetime(time.Duration)
|
SetConnMaxLifetime(time.Duration)
|
||||||
|
SetColumnMapper(core.IMapper)
|
||||||
SetDefaultCacher(core.Cacher)
|
SetDefaultCacher(core.Cacher)
|
||||||
SetLogger(logger core.ILogger)
|
SetLogger(logger core.ILogger)
|
||||||
SetLogLevel(core.LogLevel)
|
SetLogLevel(core.LogLevel)
|
||||||
|
@ -99,6 +100,7 @@ type EngineInterface interface {
|
||||||
SetMaxOpenConns(int)
|
SetMaxOpenConns(int)
|
||||||
SetMaxIdleConns(int)
|
SetMaxIdleConns(int)
|
||||||
SetSchema(string)
|
SetSchema(string)
|
||||||
|
SetTableMapper(core.IMapper)
|
||||||
SetTZDatabase(tz *time.Location)
|
SetTZDatabase(tz *time.Location)
|
||||||
SetTZLocation(tz *time.Location)
|
SetTZLocation(tz *time.Location)
|
||||||
ShowExecTime(...bool)
|
ShowExecTime(...bool)
|
||||||
|
|
|
@ -101,7 +101,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
if len(condSQL) == 0 && session.statement.LimitN == 0 {
|
pLimitN := session.statement.LimitN
|
||||||
|
if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) {
|
||||||
return 0, ErrNeedDeletedCond
|
return 0, ErrNeedDeletedCond
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,8 +120,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
||||||
if len(session.statement.OrderStr) > 0 {
|
if len(session.statement.OrderStr) > 0 {
|
||||||
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
|
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
|
||||||
}
|
}
|
||||||
if session.statement.LimitN > 0 {
|
if pLimitN != nil && *pLimitN > 0 {
|
||||||
orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN)
|
limitNValue := *pLimitN
|
||||||
|
orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(orderSQL) > 0 {
|
if len(orderSQL) > 0 {
|
||||||
|
|
|
@ -680,7 +680,7 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac
|
||||||
|
|
||||||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
|
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
|
||||||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
|
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
|
||||||
if col.Nullable && isZero(fieldValue.Interface()) {
|
if col.Nullable && isZeroValue(fieldValue) {
|
||||||
var nilValue *int
|
var nilValue *int
|
||||||
fieldValue = reflect.ValueOf(nilValue)
|
fieldValue = reflect.ValueOf(nilValue)
|
||||||
}
|
}
|
||||||
|
@ -735,66 +735,7 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
|
||||||
args = append(args, m[colName])
|
args = append(args, m[colName])
|
||||||
}
|
}
|
||||||
|
|
||||||
w := builder.NewWriter()
|
return session.insertMap(columns, args)
|
||||||
if session.statement.cond.IsValid() {
|
|
||||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.WriteString(") SELECT "); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := session.statement.writeArgs(w, args); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(exprs.args) > 0 {
|
|
||||||
if _, err := w.WriteString(","); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if err := exprs.writeArgs(w); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := w.WriteString(fmt.Sprintf(" FROM %s WHERE ", session.engine.Quote(tableName))); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := session.statement.cond.WriteTo(w); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
qm := strings.Repeat("?,", len(columns))
|
|
||||||
qm = qm[:len(qm)-1]
|
|
||||||
|
|
||||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
w.Append(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
sql := w.String()
|
|
||||||
args = w.Args()
|
|
||||||
|
|
||||||
if err := session.cacheInsert(tableName); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := session.exec(sql, args...)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
affected, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return affected, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
||||||
|
@ -814,6 +755,7 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
||||||
columns = append(columns, k)
|
columns = append(columns, k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(columns)
|
sort.Strings(columns)
|
||||||
|
|
||||||
var args = make([]interface{}, 0, len(m))
|
var args = make([]interface{}, 0, len(m))
|
||||||
|
@ -821,7 +763,18 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
||||||
args = append(args, m[colName])
|
args = append(args, m[colName])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return session.insertMap(columns, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
|
||||||
|
tableName := session.statement.TableName()
|
||||||
|
if len(tableName) <= 0 {
|
||||||
|
return 0, ErrTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
exprs := session.statement.exprColumns
|
||||||
w := builder.NewWriter()
|
w := builder.NewWriter()
|
||||||
|
// if insert where
|
||||||
if session.statement.cond.IsValid() {
|
if session.statement.cond.IsValid() {
|
||||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
|
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -859,10 +812,29 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
||||||
qm := strings.Repeat("?,", len(columns))
|
qm := strings.Repeat("?,", len(columns))
|
||||||
qm = qm[:len(qm)-1]
|
qm = qm[:len(qm)-1]
|
||||||
|
|
||||||
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)); err != nil {
|
if _, err := w.WriteString(fmt.Sprintf("INSERT INTO %s (", session.engine.Quote(tableName))); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := writeStrings(w, append(columns, exprs.colNames...), "`", "`"); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if _, err := w.WriteString(fmt.Sprintf(") VALUES (%s", qm)); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
w.Append(args...)
|
w.Append(args...)
|
||||||
|
if len(exprs.args) > 0 {
|
||||||
|
if _, err := w.WriteString(","); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if err := exprs.writeArgs(w); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := w.WriteString(")"); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sql := w.String()
|
sql := w.String()
|
||||||
|
|
|
@ -952,6 +952,64 @@ func TestInsertWhere(t *testing.T) {
|
||||||
assert.EqualValues(t, 5, j5.Index)
|
assert.EqualValues(t, 5, j5.Index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestInsertExpr2(t *testing.T) {
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
|
||||||
|
type InsertExprsRelease struct {
|
||||||
|
Id int64
|
||||||
|
RepoId int
|
||||||
|
IsTag bool
|
||||||
|
IsDraft bool
|
||||||
|
NumCommits int
|
||||||
|
Sha1 string
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSync(t, new(InsertExprsRelease))
|
||||||
|
|
||||||
|
var ie = InsertExprsRelease{
|
||||||
|
RepoId: 1,
|
||||||
|
IsTag: true,
|
||||||
|
}
|
||||||
|
inserted, err := testEngine.
|
||||||
|
SetExpr("is_draft", true).
|
||||||
|
SetExpr("num_commits", 0).
|
||||||
|
SetExpr("sha1", "").
|
||||||
|
Insert(&ie)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, inserted)
|
||||||
|
|
||||||
|
var ie2 InsertExprsRelease
|
||||||
|
has, err := testEngine.ID(ie.Id).Get(&ie2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
assert.EqualValues(t, true, ie2.IsDraft)
|
||||||
|
assert.EqualValues(t, "", ie2.Sha1)
|
||||||
|
assert.EqualValues(t, 0, ie2.NumCommits)
|
||||||
|
assert.EqualValues(t, 1, ie2.RepoId)
|
||||||
|
assert.EqualValues(t, true, ie2.IsTag)
|
||||||
|
|
||||||
|
inserted, err = testEngine.Table(new(InsertExprsRelease)).
|
||||||
|
SetExpr("is_draft", true).
|
||||||
|
SetExpr("num_commits", 0).
|
||||||
|
SetExpr("sha1", "").
|
||||||
|
Insert(map[string]interface{}{
|
||||||
|
"repo_id": 1,
|
||||||
|
"is_tag": true,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, inserted)
|
||||||
|
|
||||||
|
var ie3 InsertExprsRelease
|
||||||
|
has, err = testEngine.ID(ie.Id + 1).Get(&ie3)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
assert.EqualValues(t, true, ie3.IsDraft)
|
||||||
|
assert.EqualValues(t, "", ie3.Sha1)
|
||||||
|
assert.EqualValues(t, 0, ie3.NumCommits)
|
||||||
|
assert.EqualValues(t, 1, ie3.RepoId)
|
||||||
|
assert.EqualValues(t, true, ie3.IsTag)
|
||||||
|
}
|
||||||
|
|
||||||
type NightlyRate struct {
|
type NightlyRate struct {
|
||||||
ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"`
|
ID int64 `xorm:"'id' not null pk BIGINT(20)" json:"id"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,9 @@
|
||||||
|
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import "reflect"
|
import (
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
// IterFunc only use by Iterate
|
// IterFunc only use by Iterate
|
||||||
type IterFunc func(idx int, bean interface{}) error
|
type IterFunc func(idx int, bean interface{}) error
|
||||||
|
@ -60,22 +62,23 @@ func (session *Session) BufferSize(size int) *Session {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
|
func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
|
||||||
if session.isAutoClose {
|
|
||||||
defer session.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
var bufferSize = session.statement.bufferSize
|
var bufferSize = session.statement.bufferSize
|
||||||
var limit = session.statement.LimitN
|
var pLimitN = session.statement.LimitN
|
||||||
if limit > 0 && bufferSize > limit {
|
if pLimitN != nil && bufferSize > *pLimitN {
|
||||||
bufferSize = limit
|
bufferSize = *pLimitN
|
||||||
}
|
}
|
||||||
var start = session.statement.Start
|
var start = session.statement.Start
|
||||||
v := rValue(bean)
|
v := rValue(bean)
|
||||||
sliceType := reflect.SliceOf(v.Type())
|
sliceType := reflect.SliceOf(v.Type())
|
||||||
var idx = 0
|
var idx = 0
|
||||||
for {
|
session.autoResetStatement = false
|
||||||
|
defer func() {
|
||||||
|
session.autoResetStatement = true
|
||||||
|
}()
|
||||||
|
|
||||||
|
for bufferSize > 0 {
|
||||||
slice := reflect.New(sliceType)
|
slice := reflect.New(sliceType)
|
||||||
if err := session.Limit(bufferSize, start).find(slice.Interface(), bean); err != nil {
|
if err := session.NoCache().Limit(bufferSize, start).find(slice.Interface(), bean); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,13 +89,13 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
|
||||||
idx++
|
idx++
|
||||||
}
|
}
|
||||||
|
|
||||||
start = start + slice.Elem().Len()
|
if bufferSize > slice.Elem().Len() {
|
||||||
if limit > 0 && idx+bufferSize > limit {
|
break
|
||||||
bufferSize = limit - idx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if bufferSize <= 0 || slice.Elem().Len() < bufferSize || idx == limit {
|
start = start + slice.Elem().Len()
|
||||||
break
|
if pLimitN != nil && start+bufferSize > *pLimitN {
|
||||||
|
bufferSize = *pLimitN - start
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -89,4 +89,15 @@ func TestBufferIterate(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 7, cnt)
|
assert.EqualValues(t, 7, cnt)
|
||||||
|
|
||||||
|
cnt = 0
|
||||||
|
err = testEngine.Where("id <= 10").BufferSize(2).Iterate(new(UserBufferIterate), func(i int, bean interface{}) error {
|
||||||
|
user := bean.(*UserBufferIterate)
|
||||||
|
assert.EqualValues(t, cnt+1, user.Id)
|
||||||
|
assert.EqualValues(t, true, user.IsMan)
|
||||||
|
cnt++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 10, cnt)
|
||||||
}
|
}
|
||||||
|
|
|
@ -239,6 +239,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
for i, colName := range exprColumns.colNames {
|
for i, colName := range exprColumns.colNames {
|
||||||
switch tp := exprColumns.args[i].(type) {
|
switch tp := exprColumns.args[i].(type) {
|
||||||
case string:
|
case string:
|
||||||
|
if len(tp) == 0 {
|
||||||
|
tp = "''"
|
||||||
|
}
|
||||||
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)
|
colNames = append(colNames, session.engine.Quote(colName)+"="+tp)
|
||||||
case *builder.Builder:
|
case *builder.Builder:
|
||||||
subQuery, subArgs, err := builder.ToSQL(tp)
|
subQuery, subArgs, err := builder.ToSQL(tp)
|
||||||
|
@ -247,6 +250,9 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
}
|
}
|
||||||
colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")")
|
colNames = append(colNames, session.engine.Quote(colName)+"=("+subQuery+")")
|
||||||
args = append(args, subArgs...)
|
args = append(args, subArgs...)
|
||||||
|
default:
|
||||||
|
colNames = append(colNames, session.engine.Quote(colName)+"=?")
|
||||||
|
args = append(args, exprColumns.args[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,22 +300,26 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
|
|
||||||
st := &session.statement
|
st := &session.statement
|
||||||
|
|
||||||
var sqlStr string
|
var (
|
||||||
var condArgs []interface{}
|
sqlStr string
|
||||||
var condSQL string
|
condArgs []interface{}
|
||||||
cond := session.statement.cond.And(autoCond)
|
condSQL string
|
||||||
|
cond = session.statement.cond.And(autoCond)
|
||||||
|
|
||||||
var doIncVer = (table != nil && table.Version != "" && session.statement.checkVersion)
|
doIncVer = isStruct && (table != nil && table.Version != "" && session.statement.checkVersion)
|
||||||
var verValue *reflect.Value
|
verValue *reflect.Value
|
||||||
|
)
|
||||||
if doIncVer {
|
if doIncVer {
|
||||||
verValue, err = table.VersionColumn().ValueOf(bean)
|
verValue, err = table.VersionColumn().ValueOf(bean)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if verValue != nil {
|
||||||
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
cond = cond.And(builder.Eq{session.engine.Quote(table.Version): verValue.Interface()})
|
||||||
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
|
colNames = append(colNames, session.engine.Quote(table.Version)+" = "+session.engine.Quote(table.Version)+" + 1")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -327,11 +337,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
var tableName = session.statement.TableName()
|
var tableName = session.statement.TableName()
|
||||||
// TODO: Oracle support needed
|
// TODO: Oracle support needed
|
||||||
var top string
|
var top string
|
||||||
if st.LimitN > 0 {
|
if st.LimitN != nil {
|
||||||
|
limitValue := *st.LimitN
|
||||||
if st.Engine.dialect.DBType() == core.MYSQL {
|
if st.Engine.dialect.DBType() == core.MYSQL {
|
||||||
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||||
} else if st.Engine.dialect.DBType() == core.SQLITE {
|
} else if st.Engine.dialect.DBType() == core.SQLITE {
|
||||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
|
||||||
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
@ -342,7 +353,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
} else if st.Engine.dialect.DBType() == core.POSTGRES {
|
} else if st.Engine.dialect.DBType() == core.POSTGRES {
|
||||||
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
|
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
|
||||||
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
|
||||||
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
session.engine.Quote(tableName), tempCondSQL), condArgs...))
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
@ -357,7 +368,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
|
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
|
||||||
table != nil && len(table.PrimaryKeys) == 1 {
|
table != nil && len(table.PrimaryKeys) == 1 {
|
||||||
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
|
||||||
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
|
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
|
||||||
session.engine.Quote(tableName), condSQL), condArgs...)
|
session.engine.Quote(tableName), condSQL), condArgs...)
|
||||||
|
|
||||||
condSQL, condArgs, err = builder.ToSQL(cond)
|
condSQL, condArgs, err = builder.ToSQL(cond)
|
||||||
|
@ -368,7 +379,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
||||||
condSQL = "WHERE " + condSQL
|
condSQL = "WHERE " + condSQL
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
top = fmt.Sprintf("TOP (%d) ", st.LimitN)
|
top = fmt.Sprintf("TOP (%d) ", limitValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -506,7 +517,7 @@ func (session *Session) genUpdateColumns(bean interface{}) ([]string, []interfac
|
||||||
|
|
||||||
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
|
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
|
||||||
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
|
if _, ok := getFlagForColumn(session.statement.nullableMap, col); ok {
|
||||||
if col.Nullable && isZero(fieldValue.Interface()) {
|
if col.Nullable && isZeroValue(fieldValue) {
|
||||||
var nilValue *int
|
var nilValue *int
|
||||||
fieldValue = reflect.ValueOf(nilValue)
|
fieldValue = reflect.ValueOf(nilValue)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1359,3 +1359,90 @@ func TestUpdateAlias(t *testing.T) {
|
||||||
assert.EqualValues(t, 2, ue.NumIssues)
|
assert.EqualValues(t, 2, ue.NumIssues)
|
||||||
assert.EqualValues(t, "lunny xiao", ue.Name)
|
assert.EqualValues(t, "lunny xiao", ue.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateExprs2(t *testing.T) {
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
|
||||||
|
type UpdateExprsRelease struct {
|
||||||
|
Id int64
|
||||||
|
RepoId int
|
||||||
|
IsTag bool
|
||||||
|
IsDraft bool
|
||||||
|
NumCommits int
|
||||||
|
Sha1 string
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSync(t, new(UpdateExprsRelease))
|
||||||
|
|
||||||
|
var uer = UpdateExprsRelease{
|
||||||
|
RepoId: 1,
|
||||||
|
IsTag: false,
|
||||||
|
IsDraft: false,
|
||||||
|
NumCommits: 1,
|
||||||
|
Sha1: "sha1",
|
||||||
|
}
|
||||||
|
inserted, err := testEngine.Insert(&uer)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, inserted)
|
||||||
|
|
||||||
|
updated, err := testEngine.
|
||||||
|
Where("repo_id = ? AND is_tag = ?", 1, false).
|
||||||
|
SetExpr("is_draft", true).
|
||||||
|
SetExpr("num_commits", 0).
|
||||||
|
SetExpr("sha1", "").
|
||||||
|
Update(new(UpdateExprsRelease))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, updated)
|
||||||
|
|
||||||
|
var uer2 UpdateExprsRelease
|
||||||
|
has, err := testEngine.ID(uer.Id).Get(&uer2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
assert.EqualValues(t, 1, uer2.RepoId)
|
||||||
|
assert.EqualValues(t, false, uer2.IsTag)
|
||||||
|
assert.EqualValues(t, true, uer2.IsDraft)
|
||||||
|
assert.EqualValues(t, 0, uer2.NumCommits)
|
||||||
|
assert.EqualValues(t, "", uer2.Sha1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateMap3(t *testing.T) {
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
|
||||||
|
type UpdateMapUser struct {
|
||||||
|
Id uint64 `xorm:"PK autoincr"`
|
||||||
|
Name string `xorm:""`
|
||||||
|
Ver uint64 `xorm:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
oldMapper := testEngine.GetColumnMapper()
|
||||||
|
defer func() {
|
||||||
|
testEngine.SetColumnMapper(oldMapper)
|
||||||
|
}()
|
||||||
|
|
||||||
|
mapper := core.NewPrefixMapper(core.SnakeMapper{}, "F")
|
||||||
|
testEngine.SetColumnMapper(mapper)
|
||||||
|
|
||||||
|
assertSync(t, new(UpdateMapUser))
|
||||||
|
|
||||||
|
_, err := testEngine.Table(new(UpdateMapUser)).Insert(map[string]interface{}{
|
||||||
|
"Fname": "first user name",
|
||||||
|
"Fver": 1,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
update := map[string]interface{}{
|
||||||
|
"Fname": "user name",
|
||||||
|
"Fver": 1,
|
||||||
|
}
|
||||||
|
rows, err := testEngine.Table(new(UpdateMapUser)).ID(1).Update(update)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, rows)
|
||||||
|
|
||||||
|
update = map[string]interface{}{
|
||||||
|
"Name": "user name",
|
||||||
|
"Ver": 1,
|
||||||
|
}
|
||||||
|
rows, err = testEngine.Table(new(UpdateMapUser)).ID(1).Update(update)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.EqualValues(t, 0, rows)
|
||||||
|
}
|
||||||
|
|
33
statement.go
33
statement.go
|
@ -20,7 +20,7 @@ type Statement struct {
|
||||||
RefTable *core.Table
|
RefTable *core.Table
|
||||||
Engine *Engine
|
Engine *Engine
|
||||||
Start int
|
Start int
|
||||||
LimitN int
|
LimitN *int
|
||||||
idParam *core.PK
|
idParam *core.PK
|
||||||
OrderStr string
|
OrderStr string
|
||||||
JoinStr string
|
JoinStr string
|
||||||
|
@ -65,7 +65,7 @@ type Statement struct {
|
||||||
func (statement *Statement) Init() {
|
func (statement *Statement) Init() {
|
||||||
statement.RefTable = nil
|
statement.RefTable = nil
|
||||||
statement.Start = 0
|
statement.Start = 0
|
||||||
statement.LimitN = 0
|
statement.LimitN = nil
|
||||||
statement.OrderStr = ""
|
statement.OrderStr = ""
|
||||||
statement.UseCascade = true
|
statement.UseCascade = true
|
||||||
statement.JoinStr = ""
|
statement.JoinStr = ""
|
||||||
|
@ -247,7 +247,7 @@ func (statement *Statement) buildUpdates(bean interface{},
|
||||||
if !includeVersion && col.IsVersion {
|
if !includeVersion && col.IsVersion {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if col.IsCreated {
|
if col.IsCreated && !columnMap.contain(col.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !includeUpdated && col.IsUpdated {
|
if !includeUpdated && col.IsUpdated {
|
||||||
|
@ -671,7 +671,7 @@ func (statement *Statement) Top(limit int) *Statement {
|
||||||
|
|
||||||
// Limit generate LIMIT start, limit statement
|
// Limit generate LIMIT start, limit statement
|
||||||
func (statement *Statement) Limit(limit int, start ...int) *Statement {
|
func (statement *Statement) Limit(limit int, start ...int) *Statement {
|
||||||
statement.LimitN = limit
|
statement.LimitN = &limit
|
||||||
if len(start) > 0 {
|
if len(start) > 0 {
|
||||||
statement.Start = start[0]
|
statement.Start = start[0]
|
||||||
}
|
}
|
||||||
|
@ -1071,9 +1071,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
|
||||||
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
|
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pLimitN := statement.LimitN
|
||||||
if dialect.DBType() == core.MSSQL {
|
if dialect.DBType() == core.MSSQL {
|
||||||
if statement.LimitN > 0 {
|
if pLimitN != nil {
|
||||||
top = fmt.Sprintf("TOP %d ", statement.LimitN)
|
LimitNValue := *pLimitN
|
||||||
|
top = fmt.Sprintf("TOP %d ", LimitNValue)
|
||||||
}
|
}
|
||||||
if statement.Start > 0 {
|
if statement.Start > 0 {
|
||||||
var column string
|
var column string
|
||||||
|
@ -1134,12 +1136,16 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
|
||||||
if needLimit {
|
if needLimit {
|
||||||
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
|
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
|
||||||
if statement.Start > 0 {
|
if statement.Start > 0 {
|
||||||
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
|
if pLimitN != nil {
|
||||||
} else if statement.LimitN > 0 {
|
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
|
||||||
fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
|
} else {
|
||||||
|
fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
|
||||||
|
}
|
||||||
|
} else if pLimitN != nil {
|
||||||
|
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
|
||||||
}
|
}
|
||||||
} else if dialect.DBType() == core.ORACLE {
|
} else if dialect.DBType() == core.ORACLE {
|
||||||
if statement.Start != 0 || statement.LimitN != 0 {
|
if statement.Start != 0 || pLimitN != nil {
|
||||||
oldString := buf.String()
|
oldString := buf.String()
|
||||||
buf.Reset()
|
buf.Reset()
|
||||||
rawColStr := columnStr
|
rawColStr := columnStr
|
||||||
|
@ -1147,7 +1153,7 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
|
||||||
rawColStr = "at.*"
|
rawColStr = "at.*"
|
||||||
}
|
}
|
||||||
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
|
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
|
||||||
columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
|
columnStr, rawColStr, oldString, statement.Start+*pLimitN, statement.Start)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1204,8 +1210,9 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
var top string
|
var top string
|
||||||
if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
|
pLimitN := statement.LimitN
|
||||||
top = fmt.Sprintf("TOP %d ", statement.LimitN)
|
if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL {
|
||||||
|
top = fmt.Sprintf("TOP %d ", *pLimitN)
|
||||||
}
|
}
|
||||||
|
|
||||||
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
|
newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
|
||||||
|
|
|
@ -69,10 +69,18 @@ func (exprs *exprParams) writeArgs(w *builder.BytesWriter) error {
|
||||||
if _, err := w.WriteString(")"); err != nil {
|
if _, err := w.WriteString(")"); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
default:
|
case string:
|
||||||
|
if arg == "" {
|
||||||
|
arg = "''"
|
||||||
|
}
|
||||||
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
|
if _, err := w.WriteString(fmt.Sprintf("%v", arg)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
if _, err := w.WriteString("?"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.Append(arg)
|
||||||
}
|
}
|
||||||
if i != len(exprs.args)-1 {
|
if i != len(exprs.args)-1 {
|
||||||
if _, err := w.WriteString(","); err != nil {
|
if _, err := w.WriteString(","); err != nil {
|
||||||
|
|
|
@ -7,8 +7,8 @@ package xorm
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"xorm.io/core"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"xorm.io/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDGonicMapper struct {
|
type IDGonicMapper struct {
|
||||||
|
@ -76,7 +76,7 @@ func TestSameMapperID(t *testing.T) {
|
||||||
for _, tb := range tables {
|
for _, tb := range tables {
|
||||||
if tb.Name == "IDSameMapper" {
|
if tb.Name == "IDSameMapper" {
|
||||||
if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" {
|
if len(tb.PKColumns()) != 1 || tb.PKColumns()[0].Name != "ID" {
|
||||||
t.Fatal(tb)
|
t.Fatalf("tb %s tb.PKColumns() is %d not 1, tb.PKColumns()[0].Name is %s not ID", tb.Name, len(tb.PKColumns()), tb.PKColumns()[0].Name)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue