diff --git a/session.go b/session.go index 2256e7d7..dbcaf195 100644 --- a/session.go +++ b/session.go @@ -649,39 +649,6 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { return nil } -func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string { - var colnames = make([]string, len(cols)) - for i, col := range cols { - if includeTableName { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." + statement.Engine.Quote(col.Name) - } else { - colnames[i] = statement.Engine.Quote(col.Name) - } - } - return strings.Join(colnames, ", ") -} - -func (statement *Statement) convertIdSql(sqlStr string) string { - if statement.RefTable != nil { - cols := statement.RefTable.PKColumns() - if len(cols) == 0 { - return "" - } - - colstrs := statement.JoinColumns(cols, false) - sqls := splitNNoCase(sqlStr, " from ", 2) - if len(sqls) != 2 { - return "" - } - if statement.Engine.dialect.DBType() == "ql" { - return fmt.Sprintf("SELECT id() FROM %v", sqls[1]) - } - return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1]) - } - return "" -} - func (session *Session) canCache() bool { if session.Statement.RefTable == nil || session.Statement.JoinStr != "" || @@ -3228,47 +3195,6 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return session.innerInsert(bean) } -func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { - if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { - return "", "" - } - - colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true) - sqls := splitNNoCase(sqlStr, "where", 2) - if len(sqls) != 2 { - if len(sqls) == 1 { - return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - colstrs, statement.Engine.Quote(statement.TableName())) - } - return "", "" - } - - var whereStr = sqls[1] - - //TODO: for postgres only, if any other database? - var paraStr string - if statement.Engine.dialect.DBType() == core.POSTGRES { - paraStr = "$" - } else if statement.Engine.dialect.DBType() == core.MSSQL { - paraStr = ":" - } - - if paraStr != "" { - if strings.Contains(sqls[1], paraStr) { - dollers := strings.Split(sqls[1], paraStr) - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) - } - } - } - - return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - colstrs, statement.Engine.Quote(statement.TableName()), - whereStr) -} - func (session *Session) cacheInsert(tables ...string) error { if session.Statement.RefTable == nil { return ErrCacheFailed diff --git a/statement.go b/statement.go index 3dd5c60c..6bf20e67 100644 --- a/statement.go +++ b/statement.go @@ -1347,3 +1347,77 @@ func (statement *Statement) processIdParam() { } } } + +func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string { + var colnames = make([]string, len(cols)) + for i, col := range cols { + if includeTableName { + colnames[i] = statement.Engine.Quote(statement.TableName()) + + "." + statement.Engine.Quote(col.Name) + } else { + colnames[i] = statement.Engine.Quote(col.Name) + } + } + return strings.Join(colnames, ", ") +} + +func (statement *Statement) convertIdSql(sqlStr string) string { + if statement.RefTable != nil { + cols := statement.RefTable.PKColumns() + if len(cols) == 0 { + return "" + } + + colstrs := statement.JoinColumns(cols, false) + sqls := splitNNoCase(sqlStr, " from ", 2) + if len(sqls) != 2 { + return "" + } + if statement.Engine.dialect.DBType() == "ql" { + return fmt.Sprintf("SELECT id() FROM %v", sqls[1]) + } + return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1]) + } + return "" +} + +func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { + return "", "" + } + + colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true) + sqls := splitNNoCase(sqlStr, "where", 2) + if len(sqls) != 2 { + if len(sqls) == 1 { + return sqls[0], fmt.Sprintf("SELECT %v FROM %v", + colstrs, statement.Engine.Quote(statement.TableName())) + } + return "", "" + } + + var whereStr = sqls[1] + + //TODO: for postgres only, if any other database? + var paraStr string + if statement.Engine.dialect.DBType() == core.POSTGRES { + paraStr = "$" + } else if statement.Engine.dialect.DBType() == core.MSSQL { + paraStr = ":" + } + + if paraStr != "" { + if strings.Contains(sqls[1], paraStr) { + dollers := strings.Split(sqls[1], paraStr) + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) + } + } + } + + return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", + colstrs, statement.Engine.Quote(statement.TableName()), + whereStr) +}