mostly complete

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2023-03-13 06:34:31 +00:00
parent 3a4fbeaa6f
commit c55adf1c26
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
2 changed files with 129 additions and 45 deletions

View File

@ -24,15 +24,29 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
table = statement.RefTable table = statement.RefTable
tableName = statement.TableName() tableName = statement.TableName()
) )
quote := statement.dialect.Quoter().Quote
write := func(args ...string) {
for _, arg := range args {
_, _ = buf.WriteString(arg)
}
}
var updateColumns []string
if doUpdate {
updateColumns = make([]string, 0, len(columns))
for _, column := range columns {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
continue
}
updateColumns = append(updateColumns, quote(column))
}
doUpdate = doUpdate && (len(updateColumns) > 0)
}
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
if _, err := buf.WriteString("INSERT IGNORE INTO "); err != nil { write("INSERT IGNORE INTO ")
return "", nil, err
}
} else { } else {
if _, err := buf.WriteString("INSERT INTO "); err != nil { write("INSERT INTO ")
return "", nil, err
}
} }
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil { if err := statement.dialect.Quoter().QuoteTo(buf.Builder, tableName); err != nil {
@ -45,28 +59,31 @@ func (statement *Statement) GenUpsertSQL(doUpdate bool, columns []string, args [
switch statement.dialect.URI().DBType { switch statement.dialect.URI().DBType {
case schemas.SQLITE, schemas.POSTGRES: case schemas.SQLITE, schemas.POSTGRES:
if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil { write(" ON CONFLICT DO ")
return "", nil, err
}
if doUpdate { if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
} else { for _, column := range updateColumns[1:] {
if _, err := buf.WriteString("NOTHING"); err != nil { write(", ", column, " = excluded.", column)
return "", nil, err
} }
} else {
write("NOTHING")
} }
case schemas.MYSQL: case schemas.MYSQL:
if doUpdate { if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT // FIXME: mysql >= 8.0.19 should use table alias
write(" ON DUPLICATE KEY ")
write("UPDATE SET ", updateColumns[0], " = VALUES(", updateColumns[0], ")")
for _, column := range updateColumns[1:] {
write(", ", column, " = VALUES(", column, ")")
}
} }
default: default:
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
} }
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || statement.dialect.URI().DBType == schemas.SQLITE) {
if _, err := buf.WriteString(" RETURNING "); err != nil { write(" RETURNING ")
return "", nil, err
}
if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil { if err := statement.dialect.Quoter().QuoteTo(buf.Builder, table.AutoIncrement); err != nil {
return "", nil, err return "", nil, err
} }
@ -91,9 +108,9 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
write("MERGE INTO ", quote(tableName)) write("MERGE INTO ", quote(tableName))
if statement.dialect.URI().DBType == schemas.MSSQL { if statement.dialect.URI().DBType == schemas.MSSQL {
write("WITH (HOLDLOCK) AS target ") write(" WITH (HOLDLOCK)")
} }
write("USING (SELECT ") write(" AS target USING (SELECT ")
uniqueCols := make([]string, 0, len(uniqueColValMap)) uniqueCols := make([]string, 0, len(uniqueColValMap))
for colName := range uniqueColValMap { for colName := range uniqueColValMap {
@ -108,6 +125,19 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
write(", ") write(", ")
} }
} }
var updateColumns []string
if doUpdate {
updateColumns = make([]string, 0, len(columns))
for _, column := range columns {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
continue
}
updateColumns = append(updateColumns, quote(column))
}
doUpdate = doUpdate && (len(updateColumns) > 0)
}
write(") AS src ON (") write(") AS src ON (")
countUniques := 0 countUniques := 0
@ -126,10 +156,15 @@ func (statement *Statement) genMergeSQL(doUpdate bool, columns []string, args []
} }
write(")") write(")")
} }
write(")")
if doUpdate { if doUpdate {
return "", nil, fmt.Errorf("unimplemented") write(" WHEN MATCHED THEN UPDATE SET ")
write("src.", quote(updateColumns[0]), "= target.", quote(updateColumns[0]))
for _, col := range updateColumns[1:] {
write(", src.", quote(col), "= target.", quote(col))
}
} }
write(") WHEN NOT MATCHED THEN INSERT") write(" WHEN NOT MATCHED THEN INSERT ")
if err := statement.genInsertValues(buf, columns, args); err != nil { if err := statement.genInsertValues(buf, columns, args); err != nil {
return "", nil, err return "", nil, err
} }
@ -155,6 +190,17 @@ func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, arg
_, _ = buf.WriteString(arg) _, _ = buf.WriteString(arg)
} }
} }
var updateColumns []string
if doUpdate {
updateColumns = make([]string, 0, len(columns))
for _, column := range append(columns, exprs.ColNames()...) {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
continue
}
updateColumns = append(updateColumns, quote(column))
}
doUpdate = doUpdate && (len(updateColumns) > 0)
}
if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate { if statement.dialect.URI().DBType == schemas.MYSQL && !doUpdate {
write("INSERT IGNORE INTO ", quote(tableName), " (") write("INSERT IGNORE INTO ", quote(tableName), " (")
@ -172,25 +218,29 @@ func (statement *Statement) GenUpsertMapSQL(doUpdate bool, columns []string, arg
switch statement.dialect.URI().DBType { switch statement.dialect.URI().DBType {
case schemas.SQLITE, schemas.POSTGRES: case schemas.SQLITE, schemas.POSTGRES:
if _, err := buf.WriteString(" ON CONFLICT DO "); err != nil { write(" ON CONFLICT DO ")
return "", nil, err
}
if doUpdate { if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT write("UPDATE SET ", updateColumns[0], " = excluded.", updateColumns[0])
} else { for _, column := range updateColumns[1:] {
if _, err := buf.WriteString("NOTHING"); err != nil { write(", ", column, " = excluded.", column)
return "", nil, err
} }
} else {
write("NOTHING")
} }
case schemas.MYSQL: case schemas.MYSQL:
if doUpdate { if doUpdate {
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT // FIXME: mysql >= 8.0.19 should use table alias
write(" ON DUPLICATE KEY ")
write("UPDATE SET ", updateColumns[0], " = VALUES(", updateColumns[0], ")")
for _, column := range updateColumns[1:] {
write(", ", column, " = VALUES(", column, ")")
}
} }
default: default:
return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT return "", nil, fmt.Errorf("unimplemented") // FIXME: UPSERT
} }
if len(table.AutoIncrement) > 0 && statement.dialect.URI().DBType == schemas.POSTGRES { if len(table.AutoIncrement) > 0 && (statement.dialect.URI().DBType == schemas.POSTGRES || statement.dialect.URI().DBType == schemas.SQLITE) {
if _, err := buf.WriteString(" RETURNING "); err != nil { if _, err := buf.WriteString(" RETURNING "); err != nil {
return "", nil, err return "", nil, err
} }
@ -219,9 +269,9 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args
write("MERGE INTO ", quote(tableName)) write("MERGE INTO ", quote(tableName))
if statement.dialect.URI().DBType == schemas.MSSQL { if statement.dialect.URI().DBType == schemas.MSSQL {
write("WITH (HOLDLOCK) AS target ") write(" WITH (HOLDLOCK)")
} }
write("USING (SELECT ") write(" AS target USING (SELECT ")
uniqueCols := make([]string, 0, len(uniqueColValMap)) uniqueCols := make([]string, 0, len(uniqueColValMap))
for colName := range uniqueColValMap { for colName := range uniqueColValMap {
@ -236,6 +286,18 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args
write(", ") write(", ")
} }
} }
var updateColumns []string
if doUpdate {
updateColumns = make([]string, 0, len(columns))
for _, column := range append(columns, exprs.ColNames()...) {
if _, has := uniqueColValMap[schemas.CommonQuoter.Trim(column)]; has {
continue
}
updateColumns = append(updateColumns, quote(column))
}
doUpdate = doUpdate && (len(updateColumns) > 0)
}
write(") AS src ON (") write(") AS src ON (")
countUniques := 0 countUniques := 0
@ -254,10 +316,16 @@ func (statement *Statement) genMergeMapSQL(doUpdate bool, columns []string, args
} }
write(")") write(")")
} }
write(")")
if doUpdate { if doUpdate {
return "", nil, fmt.Errorf("unimplemented") write(" WHEN MATCHED THEN UPDATE SET ")
write("src.", quote(updateColumns[0]), "= target.", quote(updateColumns[0]))
for _, col := range updateColumns[1:] {
write(", src.", quote(col), "= target.", quote(col))
}
} }
write(") WHEN NOT MATCHED THEN INSERT") write(" WHEN NOT MATCHED THEN INSERT ")
if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil { if err := statement.dialect.Quoter().JoinWrite(buf.Builder, append(columns, exprs.ColNames()...), ","); err != nil {
return "", nil, err return "", nil, err
} }

View File

@ -37,7 +37,7 @@ func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, erro
switch v := bean.(type) { switch v := bean.(type) {
case map[string]interface{}: case map[string]interface{}:
cnt, err = session.upsertMapInterface(doUpdate, v) cnt, err = session.upsertMapInterface(doUpdate, v)
case []map[string]interface{}: case []map[string]interface{}: // FIXME: handle multiple
for _, m := range v { for _, m := range v {
cnt, err := session.upsertMapInterface(doUpdate, m) cnt, err := session.upsertMapInterface(doUpdate, m)
if err != nil { if err != nil {
@ -47,7 +47,7 @@ func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, erro
} }
case map[string]string: case map[string]string:
cnt, err = session.upsertMapString(doUpdate, v) cnt, err = session.upsertMapString(doUpdate, v)
case []map[string]string: case []map[string]string: // FIXME: handle multiple
for _, m := range v { for _, m := range v {
cnt, err := session.upsertMapString(doUpdate, m) cnt, err := session.upsertMapString(doUpdate, m)
if err != nil { if err != nil {
@ -57,7 +57,7 @@ func (session *Session) upsert(doUpdate bool, beans ...interface{}) (int64, erro
} }
default: default:
sliceValue := reflect.Indirect(reflect.ValueOf(bean)) sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice { if sliceValue.Kind() == reflect.Slice { // FIXME: handle multiple
if sliceValue.Len() <= 0 { if sliceValue.Len() <= 0 {
return 0, ErrNoElementsOnSlice return 0, ErrNoElementsOnSlice
} }
@ -140,7 +140,7 @@ func (session *Session) upsertMap(doUpdate bool, columns []string, args []interf
if err != nil { if err != nil {
return 0, err return 0, err
} }
return affected, fmt.Errorf("unimplemented") return affected, nil
} }
func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) { func (session *Session) upsertStruct(doUpdate bool, bean interface{}) (int64, error) {
@ -273,21 +273,37 @@ func (session *Session) getUniqueColumns(colNames []string, args []interface{})
} }
if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement { if indexColumn.MapType == schemas.ONLYFROMDB || indexColumn.IsAutoIncrement {
continue continue indexCol
} }
// FIXME: what do we do here?!
if session.statement.OmitColumnMap.Contain(indexColumn.Name) { if session.statement.OmitColumnMap.Contain(indexColumn.Name) {
continue continue indexCol
} }
// FIXME: what do we do here?!
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) { if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(indexColumn.Name) {
continue continue indexCol
} }
// FIXME: what do we do here?! // FIXME: what do we do here?!
if session.statement.IncrColumns.IsColExist(indexColumn.Name) { if session.statement.IncrColumns.IsColExist(indexColumn.Name) {
continue for _, exprCol := range session.statement.IncrColumns {
if exprCol.ColName == indexColumn.Name {
uniqueColValMap[indexColumnName] = exprCol.Arg
}
}
continue indexCol
} else if session.statement.DecrColumns.IsColExist(indexColumn.Name) { } else if session.statement.DecrColumns.IsColExist(indexColumn.Name) {
continue for _, exprCol := range session.statement.DecrColumns {
if exprCol.ColName == indexColumn.Name {
uniqueColValMap[indexColumnName] = exprCol.Arg
}
}
continue indexCol
} else if session.statement.ExprColumns.IsColExist(indexColumn.Name) { } else if session.statement.ExprColumns.IsColExist(indexColumn.Name) {
continue for _, exprCol := range session.statement.ExprColumns {
if exprCol.ColName == indexColumn.Name {
uniqueColValMap[indexColumnName] = exprCol.Arg
}
}
} }
// FIXME: not sure if there's anything else we can do // FIXME: not sure if there's anything else we can do