diff --git a/.gitignore b/.gitignore index b698bc6f..22486d5f 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,5 @@ temp_test.go .vscode xorm.test *.sqlite3 + +.idea/ diff --git a/README.md b/README.md index ffb6fc85..0ba5f040 100644 --- a/README.md +++ b/README.md @@ -217,7 +217,7 @@ has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() ```Go var users []User err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users) -// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10 +// SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0 type Detail struct { Id int64 @@ -234,7 +234,7 @@ err := engine.Table("user").Select("user.*, detail.*"). Join("INNER", "detail", "detail.user_id = user.id"). Where("user.name = ?", name).Limit(10, 0). Find(&users) -// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 +// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0 ``` * `Iterate` and `Rows` query multiple records and record by record handle, there are two methods Iterate and Rows @@ -265,7 +265,7 @@ for rows.Next() { * `Update` update one or more records, default will update non-empty and non-zero fields except when you use Cols, AllCols and so on. ```Go -affected, err := engine.Id(1).Update(&user) +affected, err := engine.ID(1).Update(&user) // UPDATE user SET ... Where id = ? affected, err := engine.Update(&user, &User{Name:name}) @@ -276,14 +276,14 @@ affected, err := engine.In("id", ids).Update(&user) // UPDATE user SET ... Where id IN (?, ?, ?) // force update indicated columns by Cols -affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12}) +affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12}) // UPDATE user SET age = ?, updated=? Where id = ? // force NOT update indicated columns by Omit -affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12}) +affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12}) // UPDATE user SET age = ?, updated=? Where id = ? -affected, err := engine.Id(1).AllCols().Update(&user) +affected, err := engine.ID(1).AllCols().Update(&user) // UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? ``` diff --git a/README_CN.md b/README_CN.md index 71b30c8e..1781a69b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -224,7 +224,7 @@ has, err = testEngine.Table("record_exist").Where("name = ?", "test1").Exist() ```Go var users []User err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users) -// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10 +// SELECT * FROM user WHERE name = ? AND age > 10 limit 10 offset 0 type Detail struct { Id int64 @@ -241,7 +241,7 @@ err := engine.Table("user").Select("user.*, detail.*") Join("INNER", "detail", "detail.user_id = user.id"). Where("user.name = ?", name).Limit(10, 0). Find(&users) -// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 +// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 10 offset 0 ``` * `Iterate` 和 `Rows` 根据条件遍历数据库,可以有两种方式: Iterate and Rows @@ -272,7 +272,7 @@ for rows.Next() { * `Update` 更新数据,除非使用Cols,AllCols函数指明,默认只更新非空和非0的字段 ```Go -affected, err := engine.Id(1).Update(&user) +affected, err := engine.ID(1).Update(&user) // UPDATE user SET ... Where id = ? affected, err := engine.Update(&user, &User{Name:name}) @@ -283,14 +283,14 @@ affected, err := engine.In(ids).Update(&user) // UPDATE user SET ... Where id IN (?, ?, ?) // force update indicated columns by Cols -affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12}) +affected, err := engine.ID(1).Cols("age").Update(&User{Name:name, Age: 12}) // UPDATE user SET age = ?, updated=? Where id = ? // force NOT update indicated columns by Omit -affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12}) +affected, err := engine.ID(1).Omit("name").Update(&User{Name:name, Age: 12}) // UPDATE user SET age = ?, updated=? Where id = ? -affected, err := engine.Id(1).AllCols().Update(&user) +affected, err := engine.ID(1).AllCols().Update(&user) // UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? ``` diff --git a/dialect_postgres.go b/dialect_postgres.go index 83e9a101..2b2a0b78 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -764,6 +764,9 @@ var ( "YES": true, "ZONE": true, } + + // DefaultPostgresSchema default postgres schema + DefaultPostgresSchema = "public" ) type postgres struct { @@ -771,7 +774,14 @@ type postgres struct { } func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(d, db, uri, drivername, dataSourceName) + err := db.Base.Init(d, db, uri, drivername, dataSourceName) + if err != nil { + return err + } + if db.Schema == "" { + db.Schema = DefaultPostgresSchema + } + return nil } func (db *postgres) SqlType(c *core.Column) string { @@ -868,29 +878,35 @@ func (db *postgres) IndexOnTable() bool { } func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{tableName, idxName} + if len(db.Schema) == 0 { + args := []interface{}{tableName, idxName} + return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args + } + + args := []interface{}{db.Schema, tableName, idxName} return `SELECT indexname FROM pg_indexes ` + - `WHERE tablename = ? AND indexname = ?`, args + `WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args } func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + if len(db.Schema) == 0 { + args := []interface{}{tableName} + return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + } + args := []interface{}{db.Schema, tableName} + return `SELECT tablename FROM pg_tables WHERE schemaname = ? AND tablename = ?`, args } -/*func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args -}*/ - func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { - return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", - tableName, col.Name, db.SqlType(col)) + if len(db.Schema) == 0 { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) + } + return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", + db.Schema, tableName, col.Name, db.SqlType(col)) } func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { - //var unique string quote := db.Quote idxName := index.Name @@ -906,9 +922,14 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { } func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { - args := []interface{}{tableName, colName} - query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + - " AND column_name = $2" + args := []interface{}{db.Schema, tableName, colName} + query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + + " AND column_name = $3" + if len(db.Schema) == 0 { + args = []interface{}{tableName, colName} + query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + + " AND column_name = $2" + } db.LogSQL(query, args) rows, err := db.DB().Query(query, args...) @@ -921,8 +942,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { - // FIXME: the schema should be replaced by user custom's - args := []interface{}{tableName, "public"} + args := []interface{}{tableName} s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey @@ -933,7 +953,15 @@ FROM pg_attribute f LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) LEFT JOIN pg_class AS g ON p.confrelid = g.oid LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name -WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` +WHERE c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` + + var f string + if len(db.Schema) != 0 { + args = append(args, db.Schema) + f = "AND s.table_schema = $2" + } + s = fmt.Sprintf(s, f) + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1023,9 +1051,13 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att } func (db *postgres) GetTables() ([]*core.Table, error) { - // FIXME: replace public to user customrize schema - args := []interface{}{"public"} - s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1") + args := []interface{}{} + s := "SELECT tablename FROM pg_tables" + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " WHERE schemaname = $1" + } + db.LogSQL(s, args) rows, err := db.DB().Query(s, args...) @@ -1049,10 +1081,13 @@ func (db *postgres) GetTables() ([]*core.Table, error) { } func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { - // FIXME: replace the public schema to user specify schema - args := []interface{}{"public", tableName} - s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2") + args := []interface{}{tableName} + s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") db.LogSQL(s, args) + if len(db.Schema) != 0 { + args = append(args, db.Schema) + s = s + " AND schemaname=$2" + } rows, err := db.DB().Query(s, args...) if err != nil { diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 0da7fc77..2ee1e2f3 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -7,11 +7,7 @@ import ( "github.com/go-xorm/core" ) -func TestPostgresDialect(t *testing.T) { - TestParse(t) -} - -func TestParse(t *testing.T) { +func TestParsePostgres(t *testing.T) { tests := []struct { in string expected string @@ -20,10 +16,10 @@ func TestParse(t *testing.T) { {"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true}, {"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true}, {"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false}, - {"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true}, - {"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true}, + //{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true}, + //{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true}, {"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true}, - {"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, + //{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, {"dbname=db sslmode=disable", "db", true}, {"user=auser password=password dbname=db sslmode=disable", "db", true}, {"", "db", false}, diff --git a/engine.go b/engine.go index 444611af..52ec1e3f 100644 --- a/engine.go +++ b/engine.go @@ -1453,6 +1453,13 @@ func (engine *Engine) Find(beans interface{}, condiBeans ...interface{}) error { return session.Find(beans, condiBeans...) } +// FindAndCount find the results and also return the counts +func (engine *Engine) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { + session := engine.NewSession() + defer session.Close() + return session.FindAndCount(rowsSlicePtr, condiBean...) +} + // Iterate record by record handle records from table, bean's non-empty fields // are conditions. func (engine *Engine) Iterate(bean interface{}, fun IterFunc) error { diff --git a/interface.go b/interface.go index 9a3b6da0..85a46a27 100644 --- a/interface.go +++ b/interface.go @@ -30,6 +30,7 @@ type Interface interface { Exec(string, ...interface{}) (sql.Result, error) Exist(bean ...interface{}) (bool, error) Find(interface{}, ...interface{}) error + FindAndCount(interface{}, ...interface{}) (int64, error) Get(interface{}) (bool, error) GroupBy(keys string) *Session ID(interface{}) *Session diff --git a/migrate/migrate.go b/migrate/migrate.go index 6c2a13a8..cca4e523 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -122,7 +122,10 @@ func (m *Migrate) RollbackLast() error { func (m *Migrate) getLastRunnedMigration() (*Migration, error) { for i := len(m.migrations) - 1; i >= 0; i-- { migration := m.migrations[i] - if m.migrationDidRun(migration) { + run, err := m.migrationDidRun(migration) + if err != nil { + return nil, err + } else if run { return migration, nil } } @@ -165,7 +168,12 @@ func (m *Migrate) runMigration(migration *Migration) error { return ErrMissingID } - if !m.migrationDidRun(migration) { + run, err :=m.migrationDidRun(migration) + if err != nil { + return err + } + + if !run { if err := migration.Migrate(m.db); err != nil { return err } @@ -193,11 +201,9 @@ func (m *Migrate) createMigrationTableIfNotExists() error { return nil } -func (m *Migrate) migrationDidRun(mig *Migration) bool { - row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID) - var count int - row.Scan(&count) - return count > 0 +func (m *Migrate) migrationDidRun(mig *Migration) (bool, error) { + count, err := m.db.SQL(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID).Count() + return count > 0, err } func (m *Migrate) isFirstRun() bool { diff --git a/session_exist.go b/session_exist.go index 049c1ddf..378a6483 100644 --- a/session_exist.go +++ b/session_exist.go @@ -10,6 +10,7 @@ import ( "reflect" "github.com/go-xorm/builder" + "github.com/go-xorm/core" ) // Exist returns true if the record exist otherwise return false @@ -35,10 +36,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { return false, err } - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) + if session.engine.dialect.DBType() == core.MSSQL { + sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s WHERE %s", tableName, condSQL) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) + } args = condArgs } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) + if session.engine.dialect.DBType() == core.MSSQL { + sqlStr = fmt.Sprintf("SELECT top 1 * FROM %s", tableName) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) + } args = []interface{}{} } } else { diff --git a/session_find.go b/session_find.go index f95dcfef..f9b3777f 100644 --- a/session_find.go +++ b/session_find.go @@ -29,6 +29,36 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return session.find(rowsSlicePtr, condiBean...) } +// FindAndCount find the results and also return the counts +func (session *Session) FindAndCount(rowsSlicePtr interface{}, condiBean ...interface{}) (int64, error) { + if session.isAutoClose { + defer session.Close() + } + + session.autoResetStatement = false + err := session.find(rowsSlicePtr, condiBean...) + if err != nil { + return 0, err + } + + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { + return 0, errors.New("needs a pointer to a slice or a map") + } + + sliceElementType := sliceValue.Type().Elem() + if sliceElementType.Kind() == reflect.Ptr { + sliceElementType = sliceElementType.Elem() + } + session.autoResetStatement = true + + if session.statement.selectStr != "" { + session.statement.selectStr = "" + } + + return session.Count(reflect.New(sliceElementType).Interface()) +} + func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) error { sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) if sliceValue.Kind() != reflect.Slice && sliceValue.Kind() != reflect.Map { @@ -128,7 +158,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } args = append(session.statement.joinArgs, condArgs...) - sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL) + sqlStr, err = session.statement.genSelectSQL(columnStr, condSQL, true) if err != nil { return err } diff --git a/session_find_test.go b/session_find_test.go index 393e4621..20c15362 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -488,3 +488,84 @@ func TestFindBit(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, len(results)) } + +func TestFindMark(t *testing.T) { + + type Mark struct { + Mark1 string `xorm:"VARCHAR(1)"` + Mark2 string `xorm:"VARCHAR(1)"` + MarkA string `xorm:"VARCHAR(1)"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(Mark)) + + cnt, err := testEngine.Insert([]Mark{ + { + Mark1: "1", + Mark2: "2", + MarkA: "A", + }, + { + Mark1: "1", + Mark2: "2", + MarkA: "A", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + var results = make([]Mark, 0, 2) + err = testEngine.Find(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(results)) +} + +func TestFindAndCountOneFunc(t *testing.T) { + type FindAndCountStruct struct { + Id int64 + Content string + Msg bool `xorm:"bit"` + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(FindAndCountStruct)) + + cnt, err := testEngine.Insert([]FindAndCountStruct{ + { + Content: "111", + Msg: false, + }, + { + Content: "222", + Msg: true, + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + var results = make([]FindAndCountStruct, 0, 2) + cnt, err = testEngine.FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(results)) + assert.EqualValues(t, 2, cnt) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("msg = ?", true).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 1, cnt) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("msg = ?", true).Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 1, cnt) + + results = make([]FindAndCountStruct, 0, 1) + cnt, err = testEngine.Where("msg = ?", true).Select("id, content, msg"). + Limit(1).FindAndCount(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + assert.EqualValues(t, 1, cnt) +} diff --git a/session_get.go b/session_get.go index 8faf53c0..68b37af7 100644 --- a/session_get.go +++ b/session_get.go @@ -5,6 +5,7 @@ package xorm import ( + "database/sql" "errors" "reflect" "strconv" @@ -79,6 +80,13 @@ func (session *Session) nocacheGet(beanKind reflect.Kind, table *core.Table, bea return false, nil } + switch bean.(type) { + case sql.NullInt64, sql.NullBool, sql.NullFloat64, sql.NullString: + return true, rows.Scan(&bean) + case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString: + return true, rows.Scan(bean) + } + switch beanKind { case reflect.Struct: fields, err := rows.Columns() diff --git a/session_get_test.go b/session_get_test.go index 571b8077..61398d1f 100644 --- a/session_get_test.go +++ b/session_get_test.go @@ -5,6 +5,7 @@ package xorm import ( + "database/sql" "fmt" "testing" "time" @@ -31,7 +32,7 @@ func TestGetVar(t *testing.T) { Age: 28, Money: 1.5, } - _, err := testEngine.InsertOne(data) + _, err := testEngine.InsertOne(&data) assert.NoError(t, err) var msg string @@ -55,6 +56,27 @@ func TestGetVar(t *testing.T) { assert.Equal(t, true, has) assert.EqualValues(t, 28, age2) + var id sql.NullInt64 + has, err = testEngine.Table("get_var").Cols("id").Get(&id) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, true, id.Valid) + assert.EqualValues(t, data.Id, id.Int64) + + var msgNull sql.NullString + has, err = testEngine.Table("get_var").Cols("msg").Get(&msgNull) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, true, msgNull.Valid) + assert.EqualValues(t, data.Msg, msgNull.String) + + var nullMoney sql.NullFloat64 + has, err = testEngine.Table("get_var").Cols("money").Get(&nullMoney) + assert.NoError(t, err) + assert.Equal(t, true, has) + assert.Equal(t, true, nullMoney.Valid) + assert.EqualValues(t, data.Money, nullMoney.Float64) + var money float64 has, err = testEngine.Table("get_var").Cols("money").Get(&money) assert.NoError(t, err) diff --git a/session_insert_test.go b/session_insert_test.go index 05a3d3bf..080ace9b 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -662,3 +662,81 @@ func TestInsertCreatedInt64(t *testing.T) { assert.EqualValues(t, data.Created, data2.Created) } + +type MyUserinfo Userinfo + +func (MyUserinfo) TableName() string { + return "user_info" +} + +func TestInsertMulti3(t *testing.T) { + assert.NoError(t, prepareEngine()) + + testEngine.ShowSQL(true) + assertSync(t, new(MyUserinfo)) + + users := []MyUserinfo{ + {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + cnt, err := testEngine.Insert(&users) + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) + + users2 := []*MyUserinfo{ + &MyUserinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &MyUserinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + &MyUserinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &MyUserinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + + cnt, err = testEngine.Insert(&users2) + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) +} + +type MyUserinfo2 struct { + Uid int64 `xorm:"id pk not null autoincr"` + Username string `xorm:"unique"` + Departname string + Alias string `xorm:"-"` + Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool +} + +func (MyUserinfo2) TableName() string { + return "user_info" +} + +func TestInsertMulti4(t *testing.T) { + assert.NoError(t, prepareEngine()) + + testEngine.ShowSQL(true) + assertSync(t, new(MyUserinfo2)) + + users := []MyUserinfo2{ + {Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + {Username: "xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + {Username: "xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + cnt, err := testEngine.Insert(&users) + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) + + users2 := []*MyUserinfo2{ + &MyUserinfo2{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &MyUserinfo2{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + &MyUserinfo2{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &MyUserinfo2{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + + cnt, err = testEngine.Insert(&users2) + assert.NoError(t, err) + assert.EqualValues(t, len(users), cnt) +} diff --git a/session_query.go b/session_query.go index 5b4e0dc4..e8fbd8d3 100644 --- a/session_query.go +++ b/session_query.go @@ -17,7 +17,17 @@ import ( func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interface{}, error) { if len(sqlorArgs) > 0 { - return sqlorArgs[0].(string), sqlorArgs[1:], nil + switch sqlorArgs[0].(type) { + case string: + return sqlorArgs[0].(string), sqlorArgs[1:], nil + case *builder.Builder: + return sqlorArgs[0].(*builder.Builder).ToSQL() + case builder.Builder: + bd := sqlorArgs[0].(builder.Builder) + return bd.ToSQL() + default: + return "", nil, ErrUnSupportedType + } } if session.statement.RawSQL != "" { @@ -60,7 +70,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa } args := append(session.statement.joinArgs, condArgs...) - sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL) + sqlStr, err := session.statement.genSelectSQL(columnStr, condSQL, true) if err != nil { return "", nil, err } diff --git a/session_query_test.go b/session_query_test.go index e84a7142..1e556453 100644 --- a/session_query_test.go +++ b/session_query_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/go-xorm/builder" + "github.com/stretchr/testify/assert" ) @@ -183,3 +185,48 @@ func TestQueryNoParams(t *testing.T) { assert.NoError(t, err) assertResult(t, results) } + +func TestQueryWithBuilder(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type QueryWithBuilder struct { + Id int64 `xorm:"autoincr pk"` + Msg string `xorm:"varchar(255)"` + Age int + Money float32 + Created time.Time `xorm:"created"` + } + + testEngine.ShowSQL(true) + + assert.NoError(t, testEngine.Sync2(new(QueryWithBuilder))) + + var q = QueryWithBuilder{ + Msg: "message", + Age: 20, + Money: 3000, + } + cnt, err := testEngine.Insert(&q) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + assertResult := func(t *testing.T, results []map[string][]byte) { + assert.EqualValues(t, 1, len(results)) + id, err := strconv.ParseInt(string(results[0]["id"]), 10, 64) + assert.NoError(t, err) + assert.EqualValues(t, 1, id) + assert.Equal(t, "message", string(results[0]["msg"])) + + age, err := strconv.Atoi(string(results[0]["age"])) + assert.NoError(t, err) + assert.EqualValues(t, 20, age) + + money, err := strconv.ParseFloat(string(results[0]["money"]), 32) + assert.NoError(t, err) + assert.EqualValues(t, 3000, money) + } + + results, err := testEngine.Query(builder.Select("*").From("query_with_builder")) + assert.NoError(t, err) + assertResult(t, results) +} \ No newline at end of file diff --git a/session_schema.go b/session_schema.go index a2708b73..9d9edca8 100644 --- a/session_schema.go +++ b/session_schema.go @@ -255,6 +255,12 @@ func (session *Session) Sync2(beans ...interface{}) error { return err } + session.autoResetStatement = false + defer func() { + session.autoResetStatement = true + session.resetStatement() + }() + var structTables []*core.Table for _, bean := range beans { diff --git a/session_schema_test.go b/session_schema_test.go index fa2fa7eb..712f8a04 100644 --- a/session_schema_test.go +++ b/session_schema_test.go @@ -217,3 +217,51 @@ func TestCharst(t *testing.T) { panic(err) } } + +func TestSync2_1(t *testing.T) { + type WxTest struct { + Id int `xorm:"not null pk autoincr INT(64)` + Passport_user_type int16 `xorm:"null int"` + Id_delete int8 `xorm:"null int default 1"` + } + + assert.NoError(t, prepareEngine()) + + assert.NoError(t, testEngine.DropTables("wx_test")) + assert.NoError(t, testEngine.Sync2(new(WxTest))) + assert.NoError(t, testEngine.Sync2(new(WxTest))) +} + +func TestUnique_1(t *testing.T) { + type UserUnique struct { + Id int64 + UserName string `xorm:"unique varchar(25) not null"` + Password string `xorm:"varchar(255) not null"` + Admin bool `xorm:"not null"` + CreatedAt time.Time `xorm:"created"` + UpdatedAt time.Time `xorm:"updated"` + } + + assert.NoError(t, prepareEngine()) + + assert.NoError(t, testEngine.DropTables("user_unique")) + assert.NoError(t, testEngine.Sync2(new(UserUnique))) + + assert.NoError(t, testEngine.DropTables("user_unique")) + assert.NoError(t, testEngine.CreateTables(new(UserUnique))) + assert.NoError(t, testEngine.CreateUniques(new(UserUnique))) +} + +func TestSync2_2(t *testing.T) { + type TestSync2Index struct { + Id int64 + UserId int64 `xorm:"index"` + } + + assert.NoError(t, prepareEngine()) + + for i := 0; i < 10; i++ { + tableName := fmt.Sprintf("test_sync2_index_%d", i) + assert.NoError(t, testEngine.Table(tableName).Sync2(new(TestSync2Index))) + } +} diff --git a/statement.go b/statement.go index 6400425b..35c4a472 100644 --- a/statement.go +++ b/statement.go @@ -988,7 +988,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, return "", nil, err } - sqlStr, err := statement.genSelectSQL(columnStr, condSQL) + sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true) if err != nil { return "", nil, err } @@ -1018,7 +1018,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa selectSQL = "count(*)" } } - sqlStr, err := statement.genSelectSQL(selectSQL, condSQL) + sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false) if err != nil { return "", nil, err } @@ -1043,7 +1043,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return "", nil, err } - sqlStr, err := statement.genSelectSQL(sumSelect, condSQL) + sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true) if err != nil { return "", nil, err } @@ -1051,7 +1051,7 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, err error) { +func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit bool) (a string, err error) { var distinct string if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " @@ -1149,15 +1149,17 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string, e if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { - if statement.Start > 0 { - a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) - } else if statement.LimitN > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) - } - } else if dialect.DBType() == core.ORACLE { - if statement.Start != 0 || statement.LimitN != 0 { - a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) + if needLimit { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { + if statement.Start > 0 { + a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) + } else if statement.LimitN > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) + } + } else if dialect.DBType() == core.ORACLE { + if statement.Start != 0 || statement.LimitN != 0 { + a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) + } } } if statement.IsForUpdate {