more refactors

This commit is contained in:
Lunny Xiao 2022-05-30 14:15:12 +08:00
parent 274bba2d0b
commit fd823ae5bd
4 changed files with 126 additions and 77 deletions

View File

@ -400,7 +400,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION" " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC"
rows, err := queryer.QueryContext(ctx, s, args...) rows, err := queryer.QueryContext(ctx, s, args...)
if err != nil { if err != nil {

View File

@ -59,11 +59,10 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return "", nil, err return "", nil, err
} }
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) sqlStr, args, err := statement.genSelectSQL(columnStr, true, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
args := append(statement.joinArgs, condArgs...)
// for mssql and use limit // for mssql and use limit
qs := strings.Count(sqlStr, "?") qs := strings.Count(sqlStr, "?")
@ -99,12 +98,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
return "", nil, err return "", nil, err
} }
sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) return statement.genSelectSQL(sumSelect, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenGetSQL generates Get SQL // GenGetSQL generates Get SQL
@ -156,12 +150,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{},
} }
} }
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) return statement.genSelectSQL(columnStr, true, true)
if err != nil {
return "", nil, err
}
return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenCountSQL generates the SQL for counting // GenCountSQL generates the SQL for counting
@ -207,42 +196,81 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa
sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr)
} }
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, condArgs, nil
} }
func (statement *Statement) fromBuilder() *strings.Builder { func (statement *Statement) writeJoin(w builder.Writer) error {
var builder strings.Builder
quote := statement.quote
dialect := statement.dialect
builder.WriteString(" FROM ")
if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
builder.WriteString(statement.TableName())
} else {
builder.WriteString(quote(statement.TableName()))
}
if statement.TableAlias != "" {
if dialect.URI().DBType == schemas.ORACLE {
builder.WriteString(" ")
} else {
builder.WriteString(" AS ")
}
builder.WriteString(quote(statement.TableAlias))
}
if statement.JoinStr != "" { if statement.JoinStr != "" {
builder.WriteString(" ") if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil {
builder.WriteString(statement.JoinStr) return err
} }
return &builder w.Append(statement.joinArgs...)
}
return nil
}
func (statement *Statement) writeAlias(w builder.Writer) error {
if statement.TableAlias != "" {
if statement.dialect.URI().DBType == schemas.ORACLE {
if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, " AS ", statement.quote(statement.TableAlias)); err != nil {
return err
}
}
}
return nil
}
func (statement *Statement) writeTableName(w builder.Writer) error {
if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") {
if _, err := fmt.Fprint(w, statement.TableName()); err != nil {
return err
}
} else {
if _, err := fmt.Fprint(w, statement.quote(statement.TableName())); err != nil {
return err
}
}
return nil
}
func (statement *Statement) writeFrom(w builder.Writer) error {
if _, err := fmt.Fprint(w, " FROM "); err != nil {
return err
}
if err := statement.writeTableName(w); err != nil {
return err
}
if err := statement.writeAlias(w); err != nil {
return err
}
return statement.writeJoin(w)
}
func (statement *Statement) writeLimitOffset(w builder.Writer) error {
if statement.Start > 0 {
if statement.LimitN != nil {
_, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start)
return err
}
_, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start)
return err
}
if statement.LimitN != nil {
_, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN)
return err
}
// no limit statement
return nil
} }
func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) {
var ( var (
distinct string distinct string
dialect = statement.dialect dialect = statement.dialect
fromStr = statement.fromBuilder().String()
top, whereStr string top, whereStr string
mssqlCondi = builder.NewWriter() mssqlCondi = builder.NewWriter()
) )
@ -292,25 +320,29 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
} }
orderByWriter := builder.NewWriter() if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s",
column, statement.Start, column); err != nil {
return "", nil, err
}
if err := statement.writeFrom(mssqlCondi); err != nil {
return "", nil, err
}
if whereStr != "" {
if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(mssqlCondi, condWriter); err != nil {
return "", nil, err
}
}
if needOrderBy { if needOrderBy {
if err := statement.WriteOrderBy(orderByWriter); err != nil { if err := statement.WriteOrderBy(mssqlCondi); err != nil {
return "", nil, err return "", nil, err
} }
} }
if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s",
column, statement.Start, column, fromStr, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(mssqlCondi, condWriter, orderByWriter); err != nil {
return "", nil, err
}
if err := statement.WriteGroupBy(mssqlCondi); err != nil { if err := statement.WriteGroupBy(mssqlCondi); err != nil {
return "", nil, err return "", nil, err
} }
if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil { if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil {
return "", nil, err return "", nil, err
} }
@ -318,15 +350,29 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
buf := builder.NewWriter() buf := builder.NewWriter()
fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil {
return "", nil, err
}
if err := statement.writeFrom(buf); err != nil {
return "", nil, err
}
if whereStr != "" {
if _, err := fmt.Fprint(buf, whereStr); err != nil {
return "", nil, err
}
if err := utils.WriteBuilder(buf, condWriter); err != nil { if err := utils.WriteBuilder(buf, condWriter); err != nil {
return "", nil, err return "", nil, err
} }
}
if mssqlCondi.Len() > 0 { if mssqlCondi.Len() > 0 {
if len(whereStr) > 0 { if len(whereStr) > 0 {
fmt.Fprint(buf, " AND ") if _, err := fmt.Fprint(buf, " AND "); err != nil {
return "", nil, err
}
} else { } else {
fmt.Fprint(buf, " WHERE ") if _, err := fmt.Fprint(buf, " WHERE "); err != nil {
return "", nil, err
}
} }
if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { if err := utils.WriteBuilder(buf, mssqlCondi); err != nil {
@ -337,8 +383,8 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
if err := statement.WriteGroupBy(buf); err != nil { if err := statement.WriteGroupBy(buf); err != nil {
return "", nil, err return "", nil, err
} }
if statement.HavingStr != "" { if err := statement.writeHaving(buf); err != nil {
fmt.Fprint(buf, " ", statement.HavingStr) return "", nil, err
} }
if needOrderBy { if needOrderBy {
if err := statement.WriteOrderBy(buf); err != nil { if err := statement.WriteOrderBy(buf); err != nil {
@ -347,14 +393,8 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
} }
if needLimit { if needLimit {
if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE {
if statement.Start > 0 { if err := statement.writeLimitOffset(buf); err != nil {
if pLimitN != nil { return "", nil, err
fmt.Fprintf(buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(buf, " LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(buf, " LIMIT ", *pLimitN)
} }
} else if dialect.URI().DBType == schemas.ORACLE { } else if dialect.URI().DBType == schemas.ORACLE {
if pLimitN != nil { if pLimitN != nil {
@ -500,11 +540,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) sqlStr, args, err = statement.genSelectSQL(columnStr, true, true)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
args = append(statement.joinArgs, condArgs...)
// for mssql and use limit // for mssql and use limit
qs := strings.Count(sqlStr, "?") qs := strings.Count(sqlStr, "?")
if len(args)*2 == qs { if len(args)*2 == qs {

View File

@ -603,11 +603,11 @@ func (statement *Statement) GroupBy(keys string) *Statement {
} }
func (statement *Statement) WriteGroupBy(w builder.Writer) error { func (statement *Statement) WriteGroupBy(w builder.Writer) error {
if len(statement.GroupByStr) > 0 { if statement.GroupByStr == "" {
return nil
}
_, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr) _, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr)
return err return err
}
return nil
} }
// Having generate "Having conditions" statement // Having generate "Having conditions" statement
@ -616,6 +616,14 @@ func (statement *Statement) Having(conditions string) *Statement {
return statement return statement
} }
func (statement *Statement) writeHaving(w builder.Writer) error {
if statement.HavingStr == "" {
return nil
}
_, err := fmt.Fprint(w, " ", statement.HavingStr)
return err
}
// SetUnscoped always disable struct tag "deleted" // SetUnscoped always disable struct tag "deleted"
func (statement *Statement) SetUnscoped() *Statement { func (statement *Statement) SetUnscoped() *Statement {
statement.unscoped = true statement.unscoped = true

View File

@ -414,15 +414,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
updateWriter := builder.NewWriter() updateWriter := builder.NewWriter()
if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v%v", if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v",
top, top,
tableAlias, tableAlias,
strings.Join(colNames, ", "), strings.Join(colNames, ", "),
fromSQL, fromSQL); err != nil {
whereWriter.String()); err != nil { return 0, err
}
if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil {
return 0, err return 0, err
} }
updateWriter.Append(whereWriter.Args()...)
res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...)
if err != nil { if err != nil {