Add insert select where support (#1401)
This commit is contained in:
parent
b78ac8ce0a
commit
17592d96b3
|
@ -12,6 +12,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"xorm.io/builder"
|
||||||
"xorm.io/core"
|
"xorm.io/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -345,7 +346,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
for _, v := range exprColumns {
|
for _, v := range exprColumns {
|
||||||
// remove the expr columns
|
// remove the expr columns
|
||||||
for i, colName := range colNames {
|
for i, colName := range colNames {
|
||||||
if colName == v.colName {
|
if colName == strings.Trim(v.colName, "`") {
|
||||||
colNames = append(colNames[:i], colNames[i+1:]...)
|
colNames = append(colNames[:i], colNames[i+1:]...)
|
||||||
args = append(args[:i], args[i+1:]...)
|
args = append(args[:i], args[i+1:]...)
|
||||||
}
|
}
|
||||||
|
@ -371,12 +372,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
||||||
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
|
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
|
||||||
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
|
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(colPlaces) > 0 {
|
if len(colPlaces) > 0 {
|
||||||
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
|
if session.statement.cond.IsValid() {
|
||||||
session.engine.Quote(tableName),
|
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
|
||||||
quoteColumns(colNames, session.engine.Quote, ","),
|
if err != nil {
|
||||||
output,
|
return 0, err
|
||||||
colPlaces)
|
}
|
||||||
|
|
||||||
|
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v",
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
quoteColumns(colNames, session.engine.Quote, ","),
|
||||||
|
output,
|
||||||
|
colPlaces,
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
condSQL,
|
||||||
|
)
|
||||||
|
args = append(args, condArgs...)
|
||||||
|
} else {
|
||||||
|
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
quoteColumns(colNames, session.engine.Quote, ","),
|
||||||
|
output,
|
||||||
|
colPlaces)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if session.engine.dialect.DBType() == core.MYSQL {
|
if session.engine.dialect.DBType() == core.MYSQL {
|
||||||
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
|
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
|
||||||
|
@ -663,6 +682,11 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
|
||||||
return 0, ErrParamsType
|
return 0, ErrParamsType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tableName := session.statement.TableName()
|
||||||
|
if len(tableName) <= 0 {
|
||||||
|
return 0, ErrTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
var columns = make([]string, 0, len(m))
|
var columns = make([]string, 0, len(m))
|
||||||
for k := range m {
|
for k := range m {
|
||||||
columns = append(columns, k)
|
columns = append(columns, k)
|
||||||
|
@ -670,19 +694,40 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
|
||||||
sort.Strings(columns)
|
sort.Strings(columns)
|
||||||
|
|
||||||
qm := strings.Repeat("?,", len(columns))
|
qm := strings.Repeat("?,", len(columns))
|
||||||
qm = "(" + qm[:len(qm)-1] + ")"
|
|
||||||
|
|
||||||
tableName := session.statement.TableName()
|
|
||||||
if len(tableName) <= 0 {
|
|
||||||
return 0, ErrTableNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
|
|
||||||
var args = make([]interface{}, 0, len(m))
|
var args = make([]interface{}, 0, len(m))
|
||||||
for _, colName := range columns {
|
for _, colName := range columns {
|
||||||
args = append(args, m[colName])
|
args = append(args, m[colName])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// insert expr columns, override if exists
|
||||||
|
exprColumns := session.statement.getExpr()
|
||||||
|
for _, col := range exprColumns {
|
||||||
|
columns = append(columns, strings.Trim(col.colName, "`"))
|
||||||
|
qm = qm + col.expr + ","
|
||||||
|
}
|
||||||
|
|
||||||
|
qm = qm[:len(qm)-1]
|
||||||
|
|
||||||
|
var sql string
|
||||||
|
|
||||||
|
if session.statement.cond.IsValid() {
|
||||||
|
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
strings.Join(columns, "`,`"),
|
||||||
|
qm,
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
condSQL,
|
||||||
|
)
|
||||||
|
args = append(args, condArgs...)
|
||||||
|
} else {
|
||||||
|
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
|
||||||
|
}
|
||||||
|
|
||||||
if err := session.cacheInsert(tableName); err != nil {
|
if err := session.cacheInsert(tableName); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -703,26 +748,53 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
|
||||||
return 0, ErrParamsType
|
return 0, ErrParamsType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tableName := session.statement.TableName()
|
||||||
|
if len(tableName) <= 0 {
|
||||||
|
return 0, ErrTableNotFound
|
||||||
|
}
|
||||||
|
|
||||||
var columns = make([]string, 0, len(m))
|
var columns = make([]string, 0, len(m))
|
||||||
for k := range m {
|
for k := range m {
|
||||||
columns = append(columns, k)
|
columns = append(columns, k)
|
||||||
}
|
}
|
||||||
sort.Strings(columns)
|
sort.Strings(columns)
|
||||||
|
|
||||||
qm := strings.Repeat("?,", len(columns))
|
|
||||||
qm = "(" + qm[:len(qm)-1] + ")"
|
|
||||||
|
|
||||||
tableName := session.statement.TableName()
|
|
||||||
if len(tableName) <= 0 {
|
|
||||||
return 0, ErrTableNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
|
|
||||||
var args = make([]interface{}, 0, len(m))
|
var args = make([]interface{}, 0, len(m))
|
||||||
for _, colName := range columns {
|
for _, colName := range columns {
|
||||||
args = append(args, m[colName])
|
args = append(args, m[colName])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
qm := strings.Repeat("?,", len(columns))
|
||||||
|
|
||||||
|
// insert expr columns, override if exists
|
||||||
|
exprColumns := session.statement.getExpr()
|
||||||
|
for _, col := range exprColumns {
|
||||||
|
columns = append(columns, strings.Trim(col.colName, "`"))
|
||||||
|
qm = qm + col.expr + ","
|
||||||
|
}
|
||||||
|
|
||||||
|
qm = qm[:len(qm)-1]
|
||||||
|
|
||||||
|
var sql string
|
||||||
|
|
||||||
|
if session.statement.cond.IsValid() {
|
||||||
|
qm = "(" + qm[:len(qm)-1] + ")"
|
||||||
|
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
strings.Join(columns, "`,`"),
|
||||||
|
qm,
|
||||||
|
session.engine.Quote(tableName),
|
||||||
|
condSQL,
|
||||||
|
)
|
||||||
|
args = append(args, condArgs...)
|
||||||
|
} else {
|
||||||
|
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
|
||||||
|
}
|
||||||
|
|
||||||
if err := session.cacheInsert(tableName); err != nil {
|
if err := session.cacheInsert(tableName); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -834,3 +834,62 @@ func TestInsertMap(t *testing.T) {
|
||||||
assert.EqualValues(t, 10, ims[3].Height)
|
assert.EqualValues(t, 10, ims[3].Height)
|
||||||
assert.EqualValues(t, "lunny", ims[3].Name)
|
assert.EqualValues(t, "lunny", ims[3].Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`)
|
||||||
|
SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1;
|
||||||
|
*/
|
||||||
|
func TestInsertWhere(t *testing.T) {
|
||||||
|
type InsertWhere struct {
|
||||||
|
Id int64
|
||||||
|
Index int `xorm:"unique(s) notnull"`
|
||||||
|
RepoId int64 `xorm:"unique(s)"`
|
||||||
|
Width uint32
|
||||||
|
Height uint32
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, prepareEngine())
|
||||||
|
assertSync(t, new(InsertWhere))
|
||||||
|
|
||||||
|
var i = InsertWhere{
|
||||||
|
RepoId: 1,
|
||||||
|
Width: 10,
|
||||||
|
Height: 20,
|
||||||
|
Name: "trest",
|
||||||
|
}
|
||||||
|
|
||||||
|
inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
|
||||||
|
Where("repo_id=?", 1).
|
||||||
|
Insert(&i)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, inserted)
|
||||||
|
assert.EqualValues(t, 1, i.Id)
|
||||||
|
|
||||||
|
var j InsertWhere
|
||||||
|
has, err := testEngine.ID(i.Id).Get(&j)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
i.Index = 1
|
||||||
|
assert.EqualValues(t, i, j)
|
||||||
|
|
||||||
|
inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
|
||||||
|
SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
|
||||||
|
Insert(map[string]interface{}{
|
||||||
|
"repo_id": 1,
|
||||||
|
"width": 20,
|
||||||
|
"height": 40,
|
||||||
|
"name": "trest2",
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, inserted)
|
||||||
|
|
||||||
|
var j2 InsertWhere
|
||||||
|
has, err = testEngine.ID(2).Get(&j2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, has)
|
||||||
|
assert.EqualValues(t, 1, j2.RepoId)
|
||||||
|
assert.EqualValues(t, 20, j2.Width)
|
||||||
|
assert.EqualValues(t, 40, j2.Height)
|
||||||
|
assert.EqualValues(t, "trest2", j2.Name)
|
||||||
|
assert.EqualValues(t, 2, j2.Index)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue