diff --git a/LICENSE b/LICENSE index 9c1d175d..84d2ae53 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 - 2015 +Copyright (c) 2013 - 2015 The Xorm Authors All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index 25fbc7b2..ddea8f92 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ Xorm is a simple and powerful ORM for Go. +[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) + [![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") # Features diff --git a/README_CN.md b/README_CN.md index fb08040b..8ab58ea5 100644 --- a/README_CN.md +++ b/README_CN.md @@ -4,6 +4,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 +[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) + [![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) ## 特性 diff --git a/VERSION b/VERSION index af3e71ec..915f9513 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.4.3.0520 +xorm v0.4.3.0824 diff --git a/examples/goroutine.go b/examples/goroutine.go index b18fe4f8..f99c9fbe 100644 --- a/examples/goroutine.go +++ b/examples/goroutine.go @@ -33,7 +33,7 @@ func test(engine *xorm.Engine) { return } - size := 500 + size := 100 queue := make(chan int, size) for i := 0; i < size; i++ { @@ -83,7 +83,7 @@ func test(engine *xorm.Engine) { } func main() { - runtime.GOMAXPROCS(1) + runtime.GOMAXPROCS(2) fmt.Println("-----start sqlite go routines-----") engine, err := sqliteEngine() if err != nil { diff --git a/memroy_store.go b/memory_store.go similarity index 100% rename from memroy_store.go rename to memory_store.go diff --git a/mssql_dialect.go b/mssql_dialect.go index 0eef76d4..41989b91 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -509,6 +509,10 @@ func (db *mssql) CreateTableSql(table *core.Table, tableName, storeEngine, chars return sql } +func (db *mssql) ForUpdateSql(query string) string { + return query +} + func (db *mssql) Filters() []core.Filter { return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} } diff --git a/postgres_dialect.go b/postgres_dialect.go index 67ceecd0..972a214f 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -913,7 +913,8 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - args := []interface{}{tableName} + pgSchema := "public" + args := []interface{}{tableName,pgSchema} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -924,7 +925,7 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` rows, err := db.DB().Query(s, args...) if db.Logger != nil { diff --git a/processors.go b/processors.go index 18071c1f..8f95ae3b 100644 --- a/processors.go +++ b/processors.go @@ -23,6 +23,10 @@ type BeforeSetProcessor interface { BeforeSet(string, Cell) } +type AfterSetProcessor interface { + AfterSet(string, Cell) +} + // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations //// Executed before an object is validated //type BeforeValidateProcessor interface { diff --git a/rows.go b/rows.go index 2d3bbb85..fb18454d 100644 --- a/rows.go +++ b/rows.go @@ -45,7 +45,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable) } - rows.session.Engine.logSQL(sqlStr, args) + rows.session.saveLastSQL(sqlStr, args) var err error rows.stmt, err = rows.session.DB().Prepare(sqlStr) if err != nil { diff --git a/session.go b/session.go index b114f244..0aeb7383 100644 --- a/session.go +++ b/session.go @@ -46,6 +46,10 @@ type Session struct { stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) cascadeDeep int + + // !evalphobia! stored the last executed query on this session + lastSQL string + lastSQLArgs []interface{} } // Method Init reset the session as the init status. @@ -63,6 +67,9 @@ func (session *Session) Init() { session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) session.beforeClosures = make([]func(interface{}), 0) session.afterClosures = make([]func(interface{}), 0) + + session.lastSQL = "" + session.lastSQLArgs = []interface{}{} } // Method Close release the connection from pool @@ -224,6 +231,12 @@ func (session *Session) Distinct(columns ...string) *Session { return session } +// Set Read/Write locking for UPDATE +func (session *Session) ForUpdate() *Session { + session.Statement.IsForUpdate = true + return session +} + // Only not use the paramters as select or update columns func (session *Session) Omit(columns ...string) *Session { session.Statement.Omit(columns...) @@ -331,8 +344,7 @@ func (session *Session) Begin() error { session.IsAutoCommit = false session.IsCommitedOrRollbacked = false session.Tx = tx - - session.Engine.logSQL("BEGIN TRANSACTION") + session.saveLastSQL("BEGIN TRANSACTION") } return nil } @@ -340,7 +352,7 @@ func (session *Session) Begin() error { // When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.logSQL(session.Engine.dialect.RollBackStr()) + session.saveLastSQL(session.Engine.dialect.RollBackStr()) session.IsCommitedOrRollbacked = true return session.Tx.Rollback() } @@ -350,7 +362,7 @@ func (session *Session) Rollback() error { // When using transaction, Commit will commit all operations. func (session *Session) Commit() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.logSQL("COMMIT") + session.saveLastSQL("COMMIT") session.IsCommitedOrRollbacked = true var err error if err = session.Tx.Commit(); err == nil { @@ -471,7 +483,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) } - session.Engine.logSQL(sqlStr, args...) + session.saveLastSQL(sqlStr, args...) return session.Engine.LogSQLExecutionTime(sqlStr, args, func() (sql.Result, error) { if session.IsAutoCommit { @@ -614,11 +626,15 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { return nil } -func (statement *Statement) JoinColumns(cols []*core.Column) string { +func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string { var colnames = make([]string, len(cols)) for i, col := range cols { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." + statement.Engine.Quote(col.Name) + if includeTableName { + colnames[i] = statement.Engine.Quote(statement.TableName()) + + "." + statement.Engine.Quote(col.Name) + } else { + colnames[i] = statement.Engine.Quote(col.Name) + } } return strings.Join(colnames, ", ") } @@ -630,11 +646,14 @@ func (statement *Statement) convertIdSql(sqlStr string) string { return "" } - colstrs := statement.JoinColumns(cols) - sqls := splitNNoCase(sqlStr, "from", 2) + colstrs := statement.JoinColumns(cols, false) + sqls := splitNNoCase(sqlStr, " from ", 2) if len(sqls) != 2 { return "" } + if statement.Engine.dialect.DBType() == "ql" { + return fmt.Sprintf("SELECT id() FROM %v", sqls[1]) + } return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1]) } return "" @@ -1463,7 +1482,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 sql := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName)) err := session.DB().QueryRow(sql).Scan(&total) - session.Engine.logSQL(sql) + session.saveLastSQL(sql) if err != nil { return true, err } @@ -1632,6 +1651,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount } } + defer func() { + if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { + for ii, key := range fields { + b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) + } + } + }() + var tempMap = make(map[string]int) for ii, key := range fields { var idx int @@ -1681,7 +1708,6 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount hasAssigned := false switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: if rawValueType.Kind() == reflect.String { hasAssigned = true @@ -1692,6 +1718,15 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount return err } fieldValue.Set(x.Elem()) + } else if rawValueType.Kind() == reflect.Slice { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + session.Engine.LogError(err) + return err + } + fieldValue.Set(x.Elem()) } case reflect.Slice, reflect.Array: switch rawValueType.Kind() { @@ -1736,6 +1771,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount fieldValue.SetUint(uint64(vv.Int())) } case reflect.Struct: + col := table.GetColumn(key) if fieldType.ConvertibleTo(core.TimeType) { if rawValueType == core.TimeType { hasAssigned = true @@ -1743,7 +1779,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount t := vv.Convert(core.TimeType).Interface().(time.Time) z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.Engine.LogDebug("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) + session.Engine.LogDebugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local) } @@ -1765,15 +1801,35 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount // !! 增加支持sql.Scanner接口的结构,如sql.NullString hasAssigned = true if err := nulVal.Scan(vv.Interface()); err != nil { - fmt.Println("sql.Sanner error:", err.Error()) + //fmt.Println("sql.Sanner error:", err.Error()) session.Engine.LogError("sql.Sanner error:", err.Error()) hasAssigned = false } + } else if col.SQLType.IsJson() { + if rawValueType.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + session.Engine.LogError(err) + return err + } + fieldValue.Set(x.Elem()) + } else if rawValueType.Kind() == reflect.Slice { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + session.Engine.LogError(err) + return err + } + fieldValue.Set(x.Elem()) + } } else if session.Statement.UseCascade { table := session.Engine.autoMapType(*fieldValue) if table != nil { - if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + if len(table.PrimaryKeys) != 1 { + panic("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) @@ -1964,7 +2020,7 @@ func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) *sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable) } - session.Engine.logSQL(*sqlStr, paramStr...) + session.saveLastSQL(*sqlStr, paramStr...) } func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { @@ -2954,21 +3010,39 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return tf, nil } - if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { - if len(fieldTable.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) - return pkField.Interface(), nil - } else { - return 0, fmt.Errorf("no primary key for col %v", col.Name) - } - } else { + if !col.SQLType.IsJson() { // !! 增加支持driver.Valuer接口的结构,如sql.NullString if v, ok := fieldValue.Interface().(driver.Valuer); ok { return v.Value() } - return 0, fmt.Errorf("Unsupported type %v", fieldValue.Type()) + fieldTable := session.Engine.autoMapType(fieldValue) + //if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { + if len(fieldTable.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) + return pkField.Interface(), nil + } + return 0, fmt.Errorf("no primary key for col %v", col.Name) + //} } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogError(err) + return 0, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogError(err) + return 0, err + } + return bytes, nil + } + + return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) case reflect.Complex64, reflect.Complex128: bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { @@ -3002,9 +3076,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } } return bytes, nil - } else { - return nil, ErrUnSupportedType } + return nil, ErrUnSupportedType case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: return int64(fieldValue.Uint()), nil default: @@ -3094,9 +3167,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. + if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { + //assert table.AutoIncrement != "" + res, err := session.query("select seq_atable.currval from dual", args...) - if session.Engine.DriverName() != core.POSTGRES || table.AutoIncrement == "" { - res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } else { @@ -3116,14 +3190,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if table.AutoIncrement == "" { - return res.RowsAffected() + if len(res) < 1 { + return 0, errors.New("insert no error but not returned id") } - var id int64 = 0 - id, err = res.LastInsertId() - if err != nil || id <= 0 { - return res.RowsAffected() + idByte := res[0][table.AutoIncrement] + id, err := strconv.ParseInt(string(idByte), 10, 64) + if err != nil { + return 1, err } aiValue, err := table.AutoIncrColumn().ValueOf(bean) @@ -3131,8 +3205,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.Engine.LogError(err) } - if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { - return res.RowsAffected() + if aiValue == nil || !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() { + return 1, nil } var v interface{} = id @@ -3150,8 +3224,8 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } aiValue.Set(reflect.ValueOf(v)) - return res.RowsAffected() - } else { + return 1, nil + } else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { //assert table.AutoIncrement != "" sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) res, err := session.query(sqlStr, args...) @@ -3210,6 +3284,62 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { aiValue.Set(reflect.ValueOf(v)) return 1, nil + } else { + res, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } else { + handleAfterInsertProcessorFunc(bean) + } + + if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } + + if table.Version != "" && session.Statement.checkVersion { + verValue, err := table.VersionColumn().ValueOf(bean) + if err != nil { + session.Engine.LogError(err) + } else if verValue.IsValid() && verValue.CanSet() { + verValue.SetInt(1) + } + } + + if table.AutoIncrement == "" { + return res.RowsAffected() + } + + var id int64 = 0 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return res.RowsAffected() + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.Engine.LogError(err) + } + + if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { + return res.RowsAffected() + } + + var v interface{} = id + switch aiValue.Type().Kind() { + case reflect.Int32: + v = int32(id) + case reflect.Int: + v = int(id) + case reflect.Uint32: + v = uint32(id) + case reflect.Uint64: + v = uint64(id) + case reflect.Uint: + v = uint(id) + } + aiValue.Set(reflect.ValueOf(v)) + + return res.RowsAffected() } } @@ -3230,7 +3360,7 @@ func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) { return "", "" } - colstrs := statement.JoinColumns(statement.RefTable.PKColumns()) + colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true) sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { @@ -3827,6 +3957,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return res.RowsAffected() } +// saveLastSQL stores executed query information +func (session *Session) saveLastSQL(sql string, args ...interface{}) { + session.lastSQL = sql + session.lastSQLArgs = args + session.Engine.logSQL(sql, args...) +} + +// LastSQL returns last query information +func (session *Session) LastSQL() (string, []interface{}) { + return session.lastSQL, session.lastSQLArgs +} + func (s *Session) Sync2(beans ...interface{}) error { engine := s.Engine diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index 94e7d6b3..80873dbd 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -250,6 +250,10 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } +func (db *sqlite3) ForUpdateSql(query string) string { + return query +} + /*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { args := []interface{}{tableName} sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" diff --git a/statement.go b/statement.go index 5487b19f..13f88b0f 100644 --- a/statement.go +++ b/statement.go @@ -66,6 +66,7 @@ type Statement struct { UseCache bool UseAutoTime bool IsDistinct bool + IsForUpdate bool TableAlias string allUseBool bool checkVersion bool @@ -104,6 +105,7 @@ func (statement *Statement) Init() { statement.UseCache = true statement.UseAutoTime = true statement.IsDistinct = false + statement.IsForUpdate = false statement.TableAlias = "" statement.selectStr = "" statement.allUseBool = false @@ -432,6 +434,9 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } + if col.SQLType.IsJson() { + continue + } var colName string if addedTableName { @@ -534,23 +539,43 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, continue } } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogError(err) continue } - } else { - //TODO: how to handler? - panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogError(err) + continue + } + val = bytes } } else { - val = fieldValue.Interface() + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !isZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) + } + } else { + val = fieldValue.Interface() + } } } case reflect.Array, reflect.Slice, reflect.Map: @@ -794,6 +819,12 @@ func (statement *Statement) Distinct(columns ...string) *Statement { return statement } +// Generate "SELECT ... FOR UPDATE" statment +func (statement *Statement) ForUpdate() *Statement { + statement.IsForUpdate = true + return statement +} + // replace select func (s *Statement) Select(str string) *Statement { s.selectStr = str @@ -810,6 +841,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { if strings.Contains(statement.ColumnStr, ".") { statement.ColumnStr = strings.Replace(statement.ColumnStr, ".", statement.Engine.Quote("."), -1) } + statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.Quote("*"), "*", -1) return statement } @@ -1179,6 +1211,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { distinct = "DISTINCT " } + var dialect core.Dialect = statement.Engine.Dialect() var top string var mssqlCondi string /*var orderBy string @@ -1190,7 +1223,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.WhereStr != "" { whereStr = fmt.Sprintf(" WHERE %v", statement.WhereStr) if statement.ConditionStr != "" { - whereStr = fmt.Sprintf("%v %s %v", whereStr, statement.Engine.Dialect().AndStr(), + whereStr = fmt.Sprintf("%v %s %v", whereStr, dialect.AndStr(), statement.ConditionStr) } } else if statement.ConditionStr != "" { @@ -1198,7 +1231,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) if statement.TableAlias != "" { - if statement.Engine.dialect.DBType() == core.ORACLE { + if dialect.DBType() == core.ORACLE { fromStr += " " + statement.Engine.Quote(statement.TableAlias) } else { fromStr += " AS " + statement.Engine.Quote(statement.TableAlias) @@ -1208,7 +1241,7 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) } - if statement.Engine.dialect.DBType() == core.MSSQL { + if dialect.DBType() == core.MSSQL { if statement.LimitN > 0 { top = fmt.Sprintf(" TOP %d ", statement.LimitN) } @@ -1258,17 +1291,20 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if statement.Start > 0 { a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) } else if statement.LimitN > 0 { a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } - } else if statement.Engine.dialect.DBType() == core.ORACLE { + } else if dialect.DBType() == core.ORACLE { if statement.Start != 0 || statement.LimitN != 0 { a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) } } + if statement.IsForUpdate { + a = dialect.ForUpdateSql(a) + } return } diff --git a/xorm.go b/xorm.go index 8e630b9a..4f4d8938 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( ) const ( - Version string = "0.4.3.0627" + Version string = "0.4.3.0824" ) func regDrvsNDialects() bool {