diff --git a/dialects/mysql.go b/dialects/mysql.go index 82df04dd..31e7b788 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -400,7 +400,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " + alreadyQuoted + " AS NEEDS_QUOTE " + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + - " ORDER BY `COLUMNS`.ORDINAL_POSITION" + " ORDER BY `COLUMNS`.ORDINAL_POSITION ASC" rows, err := queryer.QueryContext(ctx, s, args...) if err != nil { diff --git a/internal/statements/query.go b/internal/statements/query.go index e4f53f95..9d16c891 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -59,11 +59,10 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return "", nil, err } - sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + sqlStr, args, err := statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } - args := append(statement.joinArgs, condArgs...) // for mssql and use limit qs := strings.Count(sqlStr, "?") @@ -99,12 +98,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri return "", nil, err } - sqlStr, condArgs, err := statement.genSelectSQL(sumSelect, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil + return statement.genSelectSQL(sumSelect, true, true) } // GenGetSQL generates Get SQL @@ -156,12 +150,7 @@ func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, } } - sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) - if err != nil { - return "", nil, err - } - - return sqlStr, append(statement.joinArgs, condArgs...), nil + return statement.genSelectSQL(columnStr, true, true) } // GenCountSQL generates the SQL for counting @@ -207,42 +196,81 @@ func (statement *Statement) GenCountSQL(beans ...interface{}) (string, []interfa sqlStr = fmt.Sprintf("SELECT %s FROM (%s) sub", selectSQL, sqlStr) } - return sqlStr, append(statement.joinArgs, condArgs...), nil + return sqlStr, condArgs, nil } -func (statement *Statement) fromBuilder() *strings.Builder { - var builder strings.Builder - quote := statement.quote - dialect := statement.dialect - - builder.WriteString(" FROM ") - - if dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { - builder.WriteString(statement.TableName()) - } else { - builder.WriteString(quote(statement.TableName())) - } - - if statement.TableAlias != "" { - if dialect.URI().DBType == schemas.ORACLE { - builder.WriteString(" ") - } else { - builder.WriteString(" AS ") - } - builder.WriteString(quote(statement.TableAlias)) - } +func (statement *Statement) writeJoin(w builder.Writer) error { if statement.JoinStr != "" { - builder.WriteString(" ") - builder.WriteString(statement.JoinStr) + if _, err := fmt.Fprint(w, " ", statement.JoinStr); err != nil { + return err + } + w.Append(statement.joinArgs...) } - return &builder + return nil +} + +func (statement *Statement) writeAlias(w builder.Writer) error { + if statement.TableAlias != "" { + if statement.dialect.URI().DBType == schemas.ORACLE { + if _, err := fmt.Fprint(w, " ", statement.quote(statement.TableAlias)); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, " AS ", statement.quote(statement.TableAlias)); err != nil { + return err + } + } + } + return nil +} + +func (statement *Statement) writeTableName(w builder.Writer) error { + if statement.dialect.URI().DBType == schemas.MSSQL && strings.Contains(statement.TableName(), "..") { + if _, err := fmt.Fprint(w, statement.TableName()); err != nil { + return err + } + } else { + if _, err := fmt.Fprint(w, statement.quote(statement.TableName())); err != nil { + return err + } + } + return nil +} + +func (statement *Statement) writeFrom(w builder.Writer) error { + if _, err := fmt.Fprint(w, " FROM "); err != nil { + return err + } + if err := statement.writeTableName(w); err != nil { + return err + } + if err := statement.writeAlias(w); err != nil { + return err + } + return statement.writeJoin(w) +} + +func (statement *Statement) writeLimitOffset(w builder.Writer) error { + if statement.Start > 0 { + if statement.LimitN != nil { + _, err := fmt.Fprintf(w, " LIMIT %v OFFSET %v", *statement.LimitN, statement.Start) + return err + } + _, err := fmt.Fprintf(w, " LIMIT 0 OFFSET %v", statement.Start) + return err + } + if statement.LimitN != nil { + _, err := fmt.Fprint(w, " LIMIT ", *statement.LimitN) + return err + } + // no limit statement + return nil } func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderBy bool) (string, []interface{}, error) { var ( distinct string dialect = statement.dialect - fromStr = statement.fromBuilder().String() top, whereStr string mssqlCondi = builder.NewWriter() ) @@ -292,25 +320,29 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } } - orderByWriter := builder.NewWriter() - if needOrderBy { - if err := statement.WriteOrderBy(orderByWriter); err != nil { + if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s", + column, statement.Start, column); err != nil { + return "", nil, err + } + if err := statement.writeFrom(mssqlCondi); err != nil { + return "", nil, err + } + if whereStr != "" { + if _, err := fmt.Fprint(mssqlCondi, whereStr); err != nil { + return "", nil, err + } + if err := utils.WriteBuilder(mssqlCondi, condWriter); err != nil { return "", nil, err } } - - if _, err := fmt.Fprintf(mssqlCondi, "(%s NOT IN (SELECT TOP %d %s%s%s", - column, statement.Start, column, fromStr, whereStr); err != nil { - return "", nil, err + if needOrderBy { + if err := statement.WriteOrderBy(mssqlCondi); err != nil { + return "", nil, err + } } - if err := utils.WriteBuilder(mssqlCondi, condWriter, orderByWriter); err != nil { - return "", nil, err - } - if err := statement.WriteGroupBy(mssqlCondi); err != nil { return "", nil, err } - if _, err := fmt.Fprint(mssqlCondi, "))"); err != nil { return "", nil, err } @@ -318,15 +350,29 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } buf := builder.NewWriter() - fmt.Fprintf(buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) - if err := utils.WriteBuilder(buf, condWriter); err != nil { + if _, err := fmt.Fprintf(buf, "SELECT %v%v%v", distinct, top, columnStr); err != nil { return "", nil, err } + if err := statement.writeFrom(buf); err != nil { + return "", nil, err + } + if whereStr != "" { + if _, err := fmt.Fprint(buf, whereStr); err != nil { + return "", nil, err + } + if err := utils.WriteBuilder(buf, condWriter); err != nil { + return "", nil, err + } + } if mssqlCondi.Len() > 0 { if len(whereStr) > 0 { - fmt.Fprint(buf, " AND ") + if _, err := fmt.Fprint(buf, " AND "); err != nil { + return "", nil, err + } } else { - fmt.Fprint(buf, " WHERE ") + if _, err := fmt.Fprint(buf, " WHERE "); err != nil { + return "", nil, err + } } if err := utils.WriteBuilder(buf, mssqlCondi); err != nil { @@ -337,8 +383,8 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB if err := statement.WriteGroupBy(buf); err != nil { return "", nil, err } - if statement.HavingStr != "" { - fmt.Fprint(buf, " ", statement.HavingStr) + if err := statement.writeHaving(buf); err != nil { + return "", nil, err } if needOrderBy { if err := statement.WriteOrderBy(buf); err != nil { @@ -347,14 +393,8 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB } if needLimit { if dialect.URI().DBType != schemas.MSSQL && dialect.URI().DBType != schemas.ORACLE { - if statement.Start > 0 { - if pLimitN != nil { - fmt.Fprintf(buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start) - } else { - fmt.Fprintf(buf, " LIMIT 0 OFFSET %v", statement.Start) - } - } else if pLimitN != nil { - fmt.Fprint(buf, " LIMIT ", *pLimitN) + if err := statement.writeLimitOffset(buf); err != nil { + return "", nil, err } } else if dialect.URI().DBType == schemas.ORACLE { if pLimitN != nil { @@ -500,11 +540,11 @@ func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interfa statement.cond = statement.cond.And(autoCond) - sqlStr, condArgs, err := statement.genSelectSQL(columnStr, true, true) + sqlStr, args, err = statement.genSelectSQL(columnStr, true, true) if err != nil { return "", nil, err } - args = append(statement.joinArgs, condArgs...) + // for mssql and use limit qs := strings.Count(sqlStr, "?") if len(args)*2 == qs { diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 8250921e..605d4bd7 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -603,11 +603,11 @@ func (statement *Statement) GroupBy(keys string) *Statement { } func (statement *Statement) WriteGroupBy(w builder.Writer) error { - if len(statement.GroupByStr) > 0 { - _, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr) - return err + if statement.GroupByStr == "" { + return nil } - return nil + _, err := fmt.Fprintf(w, " GROUP BY %s", statement.GroupByStr) + return err } // Having generate "Having conditions" statement @@ -616,6 +616,14 @@ func (statement *Statement) Having(conditions string) *Statement { return statement } +func (statement *Statement) writeHaving(w builder.Writer) error { + if statement.HavingStr == "" { + return nil + } + _, err := fmt.Fprint(w, " ", statement.HavingStr) + return err +} + // SetUnscoped always disable struct tag "deleted" func (statement *Statement) SetUnscoped() *Statement { statement.unscoped = true diff --git a/session_update.go b/session_update.go index 235fa5b0..66e5d980 100644 --- a/session_update.go +++ b/session_update.go @@ -414,15 +414,16 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } updateWriter := builder.NewWriter() - if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v%v", + if _, err := fmt.Fprintf(updateWriter, "UPDATE %v%v SET %v %v", top, tableAlias, strings.Join(colNames, ", "), - fromSQL, - whereWriter.String()); err != nil { + fromSQL); err != nil { + return 0, err + } + if err := utils.WriteBuilder(updateWriter, whereWriter); err != nil { return 0, err } - updateWriter.Append(whereWriter.Args()...) res, err := session.exec(updateWriter.String(), append(args, updateWriter.Args()...)...) if err != nil {