Merge branch 'master' into lunny/fix_859

This commit is contained in:
Lunny Xiao 2022-04-23 17:20:38 +08:00
commit ef088ddb5e
12 changed files with 391 additions and 101 deletions

View File

@ -3,6 +3,32 @@
This changelog goes through all the changes that have been made in each release This changelog goes through all the changes that have been made in each release
without substantial changes to our git log. without substantial changes to our git log.
## [1.3.0](https://gitea.com/xorm/xorm/releases/tag/1.3.0) - 2022-04-14
* BREAKING
* New Prepare useage (#2061)
* Make Get and Rows.Scan accept multiple parameters (#2029)
* Drop sync function and rename sync2 to sync (#2018)
* FEATURES
* Add dameng support (#2007)
* BUGFIXES
* bugfix :Oid It's a special index. You can't put it in (#2105)
* Fix new-lined query execution in master DB node. (#2066)
* Fix bug of Rows (#2048)
* Fix bug (#2046)
* fix panic when `Iterate()` fails (#2040)
* fix panic when convert sql and args with nil time.Time pointer (#2038)
* ENHANCEMENTS
* Fix to add session.statement.IsForUpdate check in Session.queryRows() (#2064)
* Expose ScanString / ScanInterface and etc (#2039)
* TESTING
* Add test for mysql tls (#2049)
* BUILD
* Upgrade dependencies modules (#2078)
* MISC
* Fix oracle keyword AS (#2109)
* Some performance optimization for get (#2043)
## [1.2.2](https://gitea.com/xorm/xorm/releases/tag/1.2.2) - 2021-08-11 ## [1.2.2](https://gitea.com/xorm/xorm/releases/tag/1.2.2) - 2021-08-11
* MISC * MISC

View File

@ -399,7 +399,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
"(SUBSTRING_INDEX(SUBSTRING(VERSION(), 4), '.', 1) = 2 && " + "(SUBSTRING_INDEX(SUBSTRING(VERSION(), 4), '.', 1) = 2 && " +
"SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))" "SUBSTRING_INDEX(SUBSTRING(VERSION(), 6), '-', 1) >= 7)))))"
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, " + " `COLUMN_KEY`, `EXTRA`, `COLUMN_COMMENT`, `CHARACTER_MAXIMUM_LENGTH`, " +
alreadyQuoted + " AS NEEDS_QUOTE " + alreadyQuoted + " AS NEEDS_QUOTE " +
"FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + "FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" +
" ORDER BY `COLUMNS`.ORDINAL_POSITION" " ORDER BY `COLUMNS`.ORDINAL_POSITION"
@ -418,8 +418,8 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
var columnName, nullableStr, colType, colKey, extra, comment string var columnName, nullableStr, colType, colKey, extra, comment string
var alreadyQuoted, isUnsigned bool var alreadyQuoted, isUnsigned bool
var colDefault *string var colDefault, maxLength *string
err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &alreadyQuoted) err = rows.Scan(&columnName, &nullableStr, &colDefault, &colType, &colKey, &extra, &comment, &maxLength, &alreadyQuoted)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -478,6 +478,14 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} }
} }
} }
} else {
switch colType {
case "MEDIUMTEXT", "LONGTEXT", "TEXT":
len1, err = strconv.Atoi(*maxLength)
if err != nil {
return nil, nil, err
}
}
} }
if isUnsigned { if isUnsigned {
colType = "UNSIGNED " + colType colType = "UNSIGNED " + colType

View File

@ -1354,6 +1354,14 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
commentSQL += fmt.Sprintf("COMMENT ON TABLE %s IS '%s'", quoter.Quote(tableName), table.Comment) commentSQL += fmt.Sprintf("COMMENT ON TABLE %s IS '%s'", quoter.Quote(tableName), table.Comment)
} }
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if len(col.Comment) > 0 {
commentSQL += fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoter.Quote(tableName), quoter.Quote(col.Name), col.Comment)
}
}
return createTableSQL + commentSQL, true, nil return createTableSQL + commentSQL, true, nil
} }

View File

@ -255,33 +255,31 @@ func TestDBVersion(t *testing.T) {
fmt.Println(testEngine.Dialect().URI().DBType, "version is", version) fmt.Println(testEngine.Dialect().URI().DBType, "version is", version)
} }
func TestGetColumns(t *testing.T) { func TestGetColumnsComment(t *testing.T) {
if testEngine.Dialect().URI().DBType != schemas.POSTGRES { switch testEngine.Dialect().URI().DBType {
case schemas.POSTGRES, schemas.MYSQL:
default:
t.Skip() t.Skip()
return return
} }
comment := "this is a comment"
type TestCommentStruct struct { type TestCommentStruct struct {
HasComment int HasComment int `xorm:"comment('this is a comment')"`
NoComment int NoComment int
} }
assertSync(t, new(TestCommentStruct)) assertSync(t, new(TestCommentStruct))
comment := "this is a comment"
sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment)
_, err := testEngine.Exec(sql)
assert.NoError(t, err)
tables, err := testEngine.DBMetas() tables, err := testEngine.DBMetas()
assert.NoError(t, err) assert.NoError(t, err)
tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct") tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct")
var hasComment, noComment string var hasComment, noComment string
for _, table := range tables { for _, table := range tables {
if table.Name == tableName { if table.Name == tableName {
col := table.GetColumn("has_comment") col := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("HasComment"))
assert.NotNil(t, col) assert.NotNil(t, col)
hasComment = col.Comment hasComment = col.Comment
col2 := table.GetColumn("no_comment") col2 := table.GetColumn(testEngine.GetColumnMapper().Obj2Table("NoComment"))
assert.NotNil(t, col2) assert.NotNil(t, col2)
noComment = col2.Comment noComment = col2.Comment
break break
@ -290,3 +288,36 @@ func TestGetColumns(t *testing.T) {
assert.Equal(t, comment, hasComment) assert.Equal(t, comment, hasComment)
assert.Zero(t, noComment) assert.Zero(t, noComment)
} }
func TestGetColumnsLength(t *testing.T) {
var max_length int
switch testEngine.Dialect().URI().DBType {
case
schemas.POSTGRES:
max_length = 0
case
schemas.MYSQL:
max_length = 65535
default:
t.Skip()
return
}
type TestLengthStringStruct struct {
Content string `xorm:"TEXT NOT NULL"`
}
assertSync(t, new(TestLengthStringStruct))
tables, err := testEngine.DBMetas()
assert.NoError(t, err)
tableLengthStringName := testEngine.GetColumnMapper().Obj2Table("TestLengthStringStruct")
for _, table := range tables {
if table.Name == tableLengthStringName {
col := table.GetColumn("content")
assert.Equal(t, col.Length, max_length)
assert.Zero(t, col.Length2)
break
}
}
}

View File

@ -70,7 +70,7 @@ func TestRows(t *testing.T) {
} }
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var tbName = testEngine.Quote(testEngine.TableName(user, true)) tbName := testEngine.Quote(testEngine.TableName(user, true))
rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows)) rows2, err := testEngine.SQL("SELECT * FROM " + tbName).Rows(new(UserRows))
assert.NoError(t, err) assert.NoError(t, err)
defer rows2.Close() defer rows2.Close()
@ -92,7 +92,7 @@ func TestRowsMyTableName(t *testing.T) {
IsMan bool IsMan bool
} }
var tableName = "user_rows_my_table_name" tableName := "user_rows_my_table_name"
assert.NoError(t, testEngine.Table(tableName).Sync(new(UserRowsMyTable))) assert.NoError(t, testEngine.Table(tableName).Sync(new(UserRowsMyTable)))
@ -206,3 +206,75 @@ func TestRowsScanVars(t *testing.T) {
assert.NoError(t, rows.Err()) assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
} }
func TestRowsScanBytes(t *testing.T) {
type RowsScanBytes struct {
Id int64
Bytes1 []byte
Bytes2 []byte
}
assert.NoError(t, PrepareEngine())
assert.NoError(t, testEngine.Sync(new(RowsScanBytes)))
cnt, err := testEngine.Insert(&RowsScanBytes{
Bytes1: []byte("bytes1"),
Bytes2: []byte("bytes2"),
}, &RowsScanBytes{
Bytes1: []byte("bytes1-1"),
Bytes2: []byte("bytes2-2"),
})
assert.NoError(t, err)
assert.EqualValues(t, 2, cnt)
{
rows, err := testEngine.Cols("bytes1, bytes2").Rows(new(RowsScanBytes))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
var bytes1 []byte
var bytes2 []byte
for rows.Next() {
err = rows.Scan(&bytes1, &bytes2)
assert.NoError(t, err)
if cnt == 0 {
assert.EqualValues(t, []byte("bytes1"), bytes1)
assert.EqualValues(t, []byte("bytes2"), bytes2)
} else if cnt == 1 {
// bytes1 now should be `bytes1` but will be override
assert.EqualValues(t, []byte("bytes1-1"), bytes1)
assert.EqualValues(t, []byte("bytes2-2"), bytes2)
}
cnt++
}
assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt)
rows.Close()
}
{
rows, err := testEngine.Cols("bytes1, bytes2").Rows(new(RowsScanBytes))
assert.NoError(t, err)
defer rows.Close()
cnt = 0
var rsb RowsScanBytes
for rows.Next() {
err = rows.Scan(&rsb)
assert.NoError(t, err)
if cnt == 0 {
assert.EqualValues(t, []byte("bytes1"), rsb.Bytes1)
assert.EqualValues(t, []byte("bytes2"), rsb.Bytes2)
} else if cnt == 1 {
// bytes1 now should be `bytes1` but will be override
assert.EqualValues(t, []byte("bytes1-1"), rsb.Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), rsb.Bytes2)
}
cnt++
}
assert.NoError(t, rows.Err())
assert.EqualValues(t, 2, cnt)
rows.Close()
}
}

View File

@ -40,14 +40,14 @@ func TestJoinLimit(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var checklist = CheckList{ checklist := CheckList{
Eid: emp.Id, Eid: emp.Id,
} }
cnt, err = testEngine.Insert(&checklist) cnt, err = testEngine.Insert(&checklist)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var salary = Salary{ salary := Salary{
Lid: checklist.Id, Lid: checklist.Id,
} }
cnt, err = testEngine.Insert(&salary) cnt, err = testEngine.Insert(&salary)
@ -89,7 +89,7 @@ func TestFind(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
users2 := make([]Userinfo, 0) users2 := make([]Userinfo, 0)
var tbName = testEngine.Quote(testEngine.TableName(new(Userinfo), true)) tbName := testEngine.Quote(testEngine.TableName(new(Userinfo), true))
err = testEngine.SQL("select * from " + tbName).Find(&users2) err = testEngine.SQL("select * from " + tbName).Find(&users2)
assert.NoError(t, err) assert.NoError(t, err)
} }
@ -119,7 +119,7 @@ func (TeamUser) TableName() string {
} }
func TestFind3(t *testing.T) { func TestFind3(t *testing.T) {
var teamUser = new(TeamUser) teamUser := new(TeamUser)
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
err := testEngine.Sync(new(Team), teamUser) err := testEngine.Sync(new(Team), teamUser)
assert.NoError(t, err) assert.NoError(t, err)
@ -426,7 +426,7 @@ func TestFindBool(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]FindBoolStruct, 0, 2) results := make([]FindBoolStruct, 0, 2)
err = testEngine.Find(&results) err = testEngine.Find(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, len(results)) assert.EqualValues(t, 2, len(results))
@ -457,7 +457,7 @@ func TestFindMark(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]Mark, 0, 2) results := make([]Mark, 0, 2)
err = testEngine.Find(&results) err = testEngine.Find(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, len(results)) assert.EqualValues(t, 2, len(results))
@ -486,7 +486,7 @@ func TestFindAndCountOneFunc(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 2, cnt) assert.EqualValues(t, 2, cnt)
var results = make([]FindAndCountStruct, 0, 2) results := make([]FindAndCountStruct, 0, 2)
cnt, err = testEngine.Limit(1).FindAndCount(&results) cnt, err = testEngine.Limit(1).FindAndCount(&results)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, len(results)) assert.EqualValues(t, 1, len(results))
@ -611,14 +611,14 @@ func TestFindAndCount2(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel)) assertSync(t, new(TestFindAndCountUser), new(TestFindAndCountHotel))
var u = TestFindAndCountUser{ u := TestFindAndCountUser{
Name: "myname", Name: "myname",
} }
cnt, err := testEngine.Insert(&u) cnt, err := testEngine.Insert(&u)
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
var hotel = TestFindAndCountHotel{ hotel := TestFindAndCountHotel{
Name: "myhotel", Name: "myhotel",
Code: "111", Code: "111",
Region: "222", Region: "222",
@ -1063,7 +1063,7 @@ func TestUpdateFind(t *testing.T) {
session := testEngine.NewSession() session := testEngine.NewSession()
defer session.Close() defer session.Close()
var tuf = TestUpdateFind{ tuf := TestUpdateFind{
Name: "test", Name: "test",
} }
_, err := session.Insert(&tuf) _, err := session.Insert(&tuf)
@ -1095,7 +1095,7 @@ func TestFindAnonymousStruct(t *testing.T) {
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
assert.NoError(t, err) assert.NoError(t, err)
var findRes = make([]struct { findRes := make([]struct {
Id int64 Id int64
Name string Name string
}, 0) }, 0)
@ -1115,3 +1115,75 @@ func TestFindAnonymousStruct(t *testing.T) {
assert.EqualValues(t, 1, findRes[0].Id) assert.EqualValues(t, 1, findRes[0].Id)
assert.EqualValues(t, "xlw", findRes[0].Name) assert.EqualValues(t, "xlw", findRes[0].Name)
} }
func TestFindBytesVars(t *testing.T) {
type FindBytesVars struct {
Id int64
Bytes1 []byte
Bytes2 []byte
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(FindBytesVars))
_, err := testEngine.Insert([]FindBytesVars{
{
Bytes1: []byte("bytes1"),
Bytes2: []byte("bytes2"),
},
{
Bytes1: []byte("bytes1-1"),
Bytes2: []byte("bytes2-2"),
},
})
assert.NoError(t, err)
var gbv []FindBytesVars
err = testEngine.Find(&gbv)
assert.NoError(t, err)
assert.EqualValues(t, 2, len(gbv))
assert.EqualValues(t, []byte("bytes1"), gbv[0].Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv[0].Bytes2)
assert.EqualValues(t, []byte("bytes1-1"), gbv[1].Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv[1].Bytes2)
err = testEngine.Find(&gbv)
assert.NoError(t, err)
assert.EqualValues(t, 4, len(gbv))
assert.EqualValues(t, []byte("bytes1"), gbv[0].Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv[0].Bytes2)
assert.EqualValues(t, []byte("bytes1-1"), gbv[1].Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv[1].Bytes2)
assert.EqualValues(t, []byte("bytes1"), gbv[2].Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv[2].Bytes2)
assert.EqualValues(t, []byte("bytes1-1"), gbv[3].Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv[3].Bytes2)
}
func TestUpdateFindDate(t *testing.T) {
type TestUpdateFindDate struct {
Id int64
Name string
Tm time.Time `xorm:"DATE created"`
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(TestUpdateFindDate))
session := testEngine.NewSession()
defer session.Close()
tuf := TestUpdateFindDate{
Name: "test",
}
_, err := session.Insert(&tuf)
assert.NoError(t, err)
_, err = session.Where("`id` = ?", tuf.Id).Update(&TestUpdateFindDate{})
assert.EqualError(t, xorm.ErrNoColumnsTobeUpdated, err.Error())
var tufs []TestUpdateFindDate
err = session.Find(&tufs)
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tufs))
assert.EqualValues(t, tuf.Tm.Format("2006-01-02"), tufs[0].Tm.Format("2006-01-02"))
}

View File

@ -35,7 +35,7 @@ func TestGetVar(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(GetVar))) assert.NoError(t, testEngine.Sync(new(GetVar)))
var data = GetVar{ data := GetVar{
Msg: "hi", Msg: "hi",
Age: 28, Age: 28,
Money: 1.5, Money: 1.5,
@ -175,7 +175,7 @@ func TestGetVar(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, false, has) assert.Equal(t, false, has)
var valuesString = make(map[string]string) valuesString := make(map[string]string)
has, err = testEngine.Table("get_var").Get(&valuesString) has, err = testEngine.Table("get_var").Get(&valuesString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -187,7 +187,7 @@ func TestGetVar(t *testing.T) {
// for mymysql driver, interface{} will be []byte, so ignore it currently // for mymysql driver, interface{} will be []byte, so ignore it currently
if testEngine.DriverName() != "mymysql" { if testEngine.DriverName() != "mymysql" {
var valuesInter = make(map[string]interface{}) valuesInter := make(map[string]interface{})
has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter) has, err = testEngine.Table("get_var").Where("`id` = ?", 1).Select("*").Get(&valuesInter)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -198,7 +198,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"])) assert.Equal(t, "1.5", fmt.Sprintf("%v", valuesInter["money"]))
} }
var valuesSliceString = make([]string, 5) valuesSliceString := make([]string, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceString) has, err = testEngine.Table("get_var").Get(&valuesSliceString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -207,7 +207,7 @@ func TestGetVar(t *testing.T) {
assert.Equal(t, "28", valuesSliceString[2]) assert.Equal(t, "28", valuesSliceString[2])
assert.Equal(t, "1.5", valuesSliceString[3]) assert.Equal(t, "1.5", valuesSliceString[3])
var valuesSliceInter = make([]interface{}, 5) valuesSliceInter := make([]interface{}, 5)
has, err = testEngine.Table("get_var").Get(&valuesSliceInter) has, err = testEngine.Table("get_var").Get(&valuesSliceInter)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -317,7 +317,7 @@ func TestGetMap(t *testing.T) {
_, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (`is_man`) VALUES (NULL)", tableName)) _, err := testEngine.Exec(fmt.Sprintf("INSERT INTO %s (`is_man`) VALUES (NULL)", tableName))
assert.NoError(t, err) assert.NoError(t, err)
var valuesString = make(map[string]string) valuesString := make(map[string]string)
has, err := testEngine.Table("userinfo_map").Get(&valuesString) has, err := testEngine.Table("userinfo_map").Get(&valuesString)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, true, has) assert.Equal(t, true, has)
@ -336,7 +336,7 @@ func TestGetError(t *testing.T) {
assertSync(t, new(GetError)) assertSync(t, new(GetError))
var info = new(GetError) info := new(GetError)
has, err := testEngine.Get(&info) has, err := testEngine.Get(&info)
assert.False(t, has) assert.False(t, has)
assert.Error(t, err) assert.Error(t, err)
@ -456,7 +456,7 @@ func TestGetActionMapping(t *testing.T) {
}) })
assert.NoError(t, err) assert.NoError(t, err)
var valuesSlice = make([]string, 2) valuesSlice := make([]string, 2)
has, err := testEngine.Table(new(ActionMapping)). has, err := testEngine.Table(new(ActionMapping)).
Cols("script_id", "rollback_id"). Cols("script_id", "rollback_id").
ID("1").Get(&valuesSlice) ID("1").Get(&valuesSlice)
@ -483,7 +483,7 @@ func TestGetStructId(t *testing.T) {
Id int64 Id int64
} }
//var id int64 // var id int64
var maxid maxidst var maxid maxidst
sql := "select max(`id`) as id from " + testEngine.Quote(testEngine.TableName(&TestGetStruct{}, true)) sql := "select max(`id`) as id from " + testEngine.Quote(testEngine.TableName(&TestGetStruct{}, true))
has, err := testEngine.SQL(sql).Get(&maxid) has, err := testEngine.SQL(sql).Get(&maxid)
@ -693,7 +693,7 @@ func TestCustomTypes(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(TestCustomizeStruct)) assertSync(t, new(TestCustomizeStruct))
var s = TestCustomizeStruct{ s := TestCustomizeStruct{
Name: "test", Name: "test",
Age: 32, Age: 32,
} }
@ -763,7 +763,7 @@ func TestGetBigFloat(t *testing.T) {
assertSync(t, new(GetBigFloat)) assertSync(t, new(GetBigFloat))
{ {
var gf = GetBigFloat{ gf := GetBigFloat{
Money: big.NewFloat(999999.99), Money: big.NewFloat(999999.99),
} }
_, err := testEngine.Insert(&gf) _, err := testEngine.Insert(&gf)
@ -774,8 +774,8 @@ func TestGetBigFloat(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
} }
type GetBigFloat2 struct { type GetBigFloat2 struct {
@ -788,7 +788,7 @@ func TestGetBigFloat(t *testing.T) {
assertSync(t, new(GetBigFloat2)) assertSync(t, new(GetBigFloat2))
{ {
var gf2 = GetBigFloat2{ gf2 := GetBigFloat2{
Money: big.NewFloat(9999999.99), Money: big.NewFloat(9999999.99),
Money2: *big.NewFloat(99.99), Money2: *big.NewFloat(99.99),
} }
@ -800,8 +800,8 @@ func TestGetBigFloat(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String()) assert.True(t, m2.String() == gf2.Money.String(), "%v != %v", m2.String(), gf2.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
var gf3 GetBigFloat2 var gf3 GetBigFloat2
has, err = testEngine.ID(gf2.Id).Get(&gf3) has, err = testEngine.ID(gf2.Id).Get(&gf3)
@ -829,7 +829,7 @@ func TestGetDecimal(t *testing.T) {
assertSync(t, new(GetDecimal)) assertSync(t, new(GetDecimal))
{ {
var gf = GetDecimal{ gf := GetDecimal{
Money: decimal.NewFromFloat(999999.99), Money: decimal.NewFromFloat(999999.99),
} }
_, err := testEngine.Insert(&gf) _, err := testEngine.Insert(&gf)
@ -840,8 +840,8 @@ func TestGetDecimal(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
} }
type GetDecimal2 struct { type GetDecimal2 struct {
@ -854,7 +854,7 @@ func TestGetDecimal(t *testing.T) {
{ {
v := decimal.NewFromFloat(999999.99) v := decimal.NewFromFloat(999999.99)
var gf = GetDecimal2{ gf := GetDecimal2{
Money: &v, Money: &v,
} }
_, err := testEngine.Insert(&gf) _, err := testEngine.Insert(&gf)
@ -865,10 +865,11 @@ func TestGetDecimal(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has) assert.True(t, has)
assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String()) assert.True(t, m.String() == gf.Money.String(), "%v != %v", m.String(), gf.Money.String())
//fmt.Println(m.Cmp(gf.Money)) // fmt.Println(m.Cmp(gf.Money))
//assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String()) // assert.True(t, m.Cmp(gf.Money) == 0, "%v != %v", m.String(), gf.Money.String())
} }
} }
func TestGetTime(t *testing.T) { func TestGetTime(t *testing.T) {
type GetTimeStruct struct { type GetTimeStruct struct {
Id int64 Id int64
@ -878,7 +879,7 @@ func TestGetTime(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(GetTimeStruct)) assertSync(t, new(GetTimeStruct))
var gts = GetTimeStruct{ gts := GetTimeStruct{
CreateTime: time.Now().In(testEngine.GetTZLocation()), CreateTime: time.Now().In(testEngine.GetTZLocation()),
} }
_, err := testEngine.Insert(&gts) _, err := testEngine.Insert(&gts)
@ -976,3 +977,39 @@ func TestGetWithPrepare(t *testing.T) {
err = sess.Commit() err = sess.Commit()
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestGetBytesVars(t *testing.T) {
type GetBytesVars struct {
Id int64
Bytes1 []byte
Bytes2 []byte
}
assert.NoError(t, PrepareEngine())
assertSync(t, new(GetBytesVars))
_, err := testEngine.Insert([]GetBytesVars{
{
Bytes1: []byte("bytes1"),
Bytes2: []byte("bytes2"),
},
{
Bytes1: []byte("bytes1-1"),
Bytes2: []byte("bytes2-2"),
},
})
assert.NoError(t, err)
var gbv GetBytesVars
has, err := testEngine.Asc("id").Get(&gbv)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, []byte("bytes1"), gbv.Bytes1)
assert.EqualValues(t, []byte("bytes2"), gbv.Bytes2)
has, err = testEngine.Desc("id").NoAutoCondition().Get(&gbv)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, []byte("bytes1-1"), gbv.Bytes1)
assert.EqualValues(t, []byte("bytes2-2"), gbv.Bytes2)
}

View File

@ -6,6 +6,7 @@ package integrations
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
@ -248,7 +249,9 @@ func TestSyncTable3(t *testing.T) {
tableInfo, err := testEngine.TableInfo(new(SyncTable5)) tableInfo, err := testEngine.TableInfo(new(SyncTable5))
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("name")), testEngine.Dialect().SQLType(tables[0].GetColumn("name"))) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("name")), testEngine.Dialect().SQLType(tables[0].GetColumn("name")))
assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("text")), testEngine.Dialect().SQLType(tables[0].GetColumn("text"))) /* Engine.DBMetas() returns the size of the column from the database but Engine.TableInfo() might not be able to guess the column size.
For example using MySQL/MariaDB: when utf-8 charset is used, "`xorm:"TEXT(21846)`" creates a MEDIUMTEXT column not a TEXT column. */
assert.True(t, testEngine.Dialect().SQLType(tables[0].GetColumn("text")) == testEngine.Dialect().SQLType(tableInfo.GetColumn("text")) || strings.HasPrefix(testEngine.Dialect().SQLType(tables[0].GetColumn("text")), testEngine.Dialect().SQLType(tableInfo.GetColumn("text"))+"("))
assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("char")), testEngine.Dialect().SQLType(tables[0].GetColumn("char"))) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("char")), testEngine.Dialect().SQLType(tables[0].GetColumn("char")))
assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("ten_char")), testEngine.Dialect().SQLType(tables[0].GetColumn("ten_char"))) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("ten_char")), testEngine.Dialect().SQLType(tables[0].GetColumn("ten_char")))
assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("ten_var_char")), testEngine.Dialect().SQLType(tables[0].GetColumn("ten_var_char"))) assert.EqualValues(t, testEngine.Dialect().SQLType(tableInfo.GetColumn("ten_var_char")), testEngine.Dialect().SQLType(tables[0].GetColumn("ten_var_char")))

View File

@ -30,7 +30,7 @@ func TestArrayField(t *testing.T) {
assert.NoError(t, testEngine.Sync(new(ArrayStruct))) assert.NoError(t, testEngine.Sync(new(ArrayStruct)))
var as = ArrayStruct{ as := ArrayStruct{
Name: [20]byte{ Name: [20]byte{
96, 96, 96, 96, 96, 96, 96, 96, 96, 96,
96, 96, 96, 96, 96, 96, 96, 96, 96, 96,
@ -54,7 +54,7 @@ func TestArrayField(t *testing.T) {
assert.EqualValues(t, 1, len(arrs)) assert.EqualValues(t, 1, len(arrs))
assert.Equal(t, as.Name, arrs[0].Name) assert.Equal(t, as.Name, arrs[0].Name)
var newName = [20]byte{ newName := [20]byte{
90, 96, 96, 96, 96, 90, 96, 96, 96, 96,
96, 96, 96, 96, 96, 96, 96, 96, 96, 96,
96, 96, 96, 96, 96, 96, 96, 96, 96, 96,
@ -252,9 +252,11 @@ func TestConversion(t *testing.T) {
assert.Nil(t, c1.Nullable2) assert.Nil(t, c1.Nullable2)
} }
type MyInt int type (
type MyUInt uint MyInt int
type MyFloat float64 MyUInt uint
MyFloat float64
)
type MyStruct struct { type MyStruct struct {
Type MyInt Type MyInt
@ -273,7 +275,7 @@ type MyStruct struct {
UIA32 []uint32 UIA32 []uint32
UIA64 []uint64 UIA64 []uint64
UI uint UI uint
//C64 complex64 // C64 complex64
MSS map[string]string MSS map[string]string
} }
@ -304,6 +306,13 @@ func TestCustomType1(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.EqualValues(t, 1, cnt) assert.EqualValues(t, 1, cnt)
// since mssql don't support use text as index condition, we have to ignore below
// get and find tests
if testEngine.Dialect().URI().DBType == schemas.MSSQL {
t.Skip()
return
}
fmt.Println(i) fmt.Println(i)
i.NameArray = []string{} i.NameArray = []string{}
i.MSS = map[string]string{} i.MSS = map[string]string{}
@ -598,7 +607,7 @@ func TestMyArray(t *testing.T) {
assert.NoError(t, PrepareEngine()) assert.NoError(t, PrepareEngine())
assertSync(t, new(MyArrayStruct)) assertSync(t, new(MyArrayStruct))
var v = [20]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1} v := [20]byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
_, err := testEngine.Insert(&MyArrayStruct{ _, err := testEngine.Insert(&MyArrayStruct{
Content: v, Content: v,
}) })

View File

@ -304,7 +304,7 @@ func (statement *Statement) needTableName() bool {
func (statement *Statement) colName(col *schemas.Column, tableName string) string { func (statement *Statement) colName(col *schemas.Column, tableName string) string {
if statement.needTableName() { if statement.needTableName() {
var nm = tableName nm := tableName
if len(statement.TableAlias) > 0 { if len(statement.TableAlias) > 0 {
nm = statement.TableAlias nm = statement.TableAlias
} }
@ -765,7 +765,7 @@ func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect
if len(table.PrimaryKeys) == 1 { if len(table.PrimaryKeys) == 1 {
pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
// fix non-int pk issues // fix non-int pk issues
//if pkField.Int() != 0 { // if pkField.Int() != 0 {
if pkField.IsValid() && !utils.IsZero(pkField.Interface()) { if pkField.IsValid() && !utils.IsZero(pkField.Interface()) {
return pkField.Interface(), true, nil return pkField.Interface(), true, nil
} }
@ -814,7 +814,8 @@ func (statement *Statement) asDBCond(fieldValue reflect.Value, fieldType reflect
func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeVersion bool, includeUpdated bool, includeNil bool,
includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool,
mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool,
) (builder.Cond, error) {
var conds []builder.Cond var conds []builder.Cond
for _, col := range table.Columns() { for _, col := range table.Columns() {
if !includeVersion && col.IsVersion { if !includeVersion && col.IsVersion {
@ -827,17 +828,13 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
continue continue
} }
if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text ||
col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
continue
}
if col.IsJSON { if col.IsJSON {
continue continue
} }
var colName string var colName string
if addedTableName { if addedTableName {
var nm = tableName nm := tableName
if len(aliasName) > 0 { if len(aliasName) > 0 {
nm = aliasName nm = aliasName
} }
@ -862,6 +859,15 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
continue continue
} }
if statement.dialect.URI().DBType == schemas.MSSQL && (col.SQLType.Name == schemas.Text ||
col.SQLType.IsBlob() || col.SQLType.Name == schemas.TimeStampz) {
if utils.IsValueZero(fieldValue) {
continue
}
return nil, fmt.Errorf("column %s is a TEXT type with data %#v which cannot be as compare condition", col.Name, fieldValue.Interface())
}
requiredField := useAllCols requiredField := useAllCols
if b, ok := getFlagForColumn(mustColumnMap, col); ok { if b, ok := getFlagForColumn(mustColumnMap, col); ok {
if b { if b {
@ -910,7 +916,7 @@ func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, i
func (statement *Statement) mergeConds(bean interface{}) error { func (statement *Statement) mergeConds(bean interface{}) error {
if !statement.NoAutoCondition && statement.RefTable != nil { if !statement.NoAutoCondition && statement.RefTable != nil {
var addedTableName = (len(statement.JoinStr) > 0) addedTableName := (len(statement.JoinStr) > 0)
autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName) autoCond, err := statement.BuildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
if err != nil { if err != nil {
return err return err
@ -948,7 +954,7 @@ func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string,
switch sqlOrArgs[0].(type) { switch sqlOrArgs[0].(type) {
case string: case string:
if len(sqlOrArgs) > 1 { if len(sqlOrArgs) > 1 {
var newArgs = make([]interface{}, 0, len(sqlOrArgs)-1) newArgs := make([]interface{}, 0, len(sqlOrArgs)-1)
for _, arg := range sqlOrArgs[1:] { for _, arg := range sqlOrArgs[1:] {
if v, ok := arg.(time.Time); ok { if v, ok := arg.(time.Time); ok {
newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05"))
@ -972,7 +978,7 @@ func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string,
} }
func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string { func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName bool) string {
var colnames = make([]string, len(cols)) colnames := make([]string, len(cols))
for i, col := range cols { for i, col := range cols {
if includeTableName { if includeTableName {
colnames[i] = statement.quote(statement.TableName()) + colnames[i] = statement.quote(statement.TableName()) +
@ -986,7 +992,7 @@ func (statement *Statement) joinColumns(cols []*schemas.Column, includeTableName
// CondDeleted returns the conditions whether a record is soft deleted. // CondDeleted returns the conditions whether a record is soft deleted.
func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond { func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
var colName = statement.quote(col.Name) colName := statement.quote(col.Name)
if statement.JoinStr != "" { if statement.JoinStr != "" {
var prefix string var prefix string
if statement.TableAlias != "" { if statement.TableAlias != "" {
@ -996,7 +1002,7 @@ func (statement *Statement) CondDeleted(col *schemas.Column) builder.Cond {
} }
colName = statement.quote(prefix) + "." + statement.quote(col.Name) colName = statement.quote(prefix) + "." + statement.quote(col.Name)
} }
var cond = builder.NewCond() cond := builder.NewCond()
if col.SQLType.IsNumeric() { if col.SQLType.IsNumeric() {
cond = builder.Eq{colName: 0} cond = builder.Eq{colName: 0}
} else { } else {

24
scan.go
View File

@ -22,7 +22,7 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
switch t := bean.(type) { switch t := bean.(type) {
case *interface{}: case *interface{}:
return t, false, nil return t, false, nil
case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes: case *sql.NullInt64, *sql.NullBool, *sql.NullFloat64, *sql.NullString, *sql.RawBytes, *[]byte:
return t, false, nil return t, false, nil
case *time.Time: case *time.Time:
return &sql.NullString{}, true, nil return &sql.NullString{}, true, nil
@ -67,7 +67,7 @@ func genScanResultsByBeanNullable(bean interface{}) (interface{}, bool, error) {
case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8: case reflect.Uint32, reflect.Uint, reflect.Uint16, reflect.Uint8:
return &convert.NullUint32{}, true, nil return &convert.NullUint32{}, true, nil
default: default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean) return nil, false, fmt.Errorf("genScanResultsByBeanNullable: unsupported type: %#v", bean)
} }
} }
@ -125,12 +125,12 @@ func genScanResultsByBean(bean interface{}) (interface{}, bool, error) {
case reflect.Float64: case reflect.Float64:
return new(float64), true, nil return new(float64), true, nil
default: default:
return nil, false, fmt.Errorf("unsupported type: %#v", bean) return nil, false, fmt.Errorf("genScanResultsByBean: unsupported type: %#v", bean)
} }
} }
func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) {
var scanResults = make([]interface{}, len(types)) scanResults := make([]interface{}, len(types))
for i := 0; i < len(types); i++ { for i := 0; i < len(types); i++ {
var s sql.NullString var s sql.NullString
scanResults[i] = &s scanResults[i] = &s
@ -144,8 +144,8 @@ func (engine *Engine) scanStringInterface(rows *core.Rows, fields []string, type
// scan is a wrap of driver.Scan but will automatically change the input values according requirements // scan is a wrap of driver.Scan but will automatically change the input values according requirements
func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error { func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types)) scanResults := make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types)) replaces := make([]bool, 0, len(types))
var err error var err error
for _, v := range vv { for _, v := range vv {
var replaced bool var replaced bool
@ -194,7 +194,7 @@ func (engine *Engine) scan(rows *core.Rows, fields []string, types []*sql.Column
} }
func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) { func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*sql.ColumnType) ([]interface{}, error) {
var scanResultContainers = make([]interface{}, len(types)) scanResultContainers := make([]interface{}, len(types))
for i := 0; i < len(types); i++ { for i := 0; i < len(types); i++ {
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil { if err != nil {
@ -212,8 +212,8 @@ func (engine *Engine) scanInterfaces(rows *core.Rows, fields []string, types []*
// row -> map[string]interface{} // row -> map[string]interface{}
func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) { func (engine *Engine) row2mapInterface(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]interface{}, error) {
var resultsMap = make(map[string]interface{}, len(fields)) resultsMap := make(map[string]interface{}, len(fields))
var scanResultContainers = make([]interface{}, len(fields)) scanResultContainers := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName()) scanResult, err := engine.driver.GenScanResult(types[i].DatabaseTypeName())
if err != nil { if err != nil {
@ -277,7 +277,7 @@ func (engine *Engine) ScanInterfaceMaps(rows *core.Rows) (resultsSlice []map[str
// row -> map[string]string // row -> map[string]string
func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) { func (engine *Engine) row2mapStr(rows *core.Rows, types []*sql.ColumnType, fields []string) (map[string]string, error) {
var scanResults = make([]interface{}, len(fields)) scanResults := make([]interface{}, len(fields))
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
var s sql.NullString var s sql.NullString
scanResults[i] = &s scanResults[i] = &s
@ -353,7 +353,7 @@ func (engine *Engine) ScanStringMaps(rows *core.Rows) (resultsSlice []map[string
// row -> map[string][]byte // row -> map[string][]byte
func convertMapStr2Bytes(m map[string]string) map[string][]byte { func convertMapStr2Bytes(m map[string]string) map[string][]byte {
var r = make(map[string][]byte, len(m)) r := make(map[string][]byte, len(m))
for k, v := range m { for k, v := range m {
r[k] = []byte(v) r[k] = []byte(v)
} }
@ -392,7 +392,7 @@ func (engine *Engine) row2sliceStr(rows *core.Rows, types []*sql.ColumnType, fie
return nil, err return nil, err
} }
var results = make([]string, 0, len(fields)) results := make([]string, 0, len(fields))
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
results = append(results, scanResults[i].(*sql.NullString).String) results = append(results, scanResults[i].(*sql.NullString).String)
} }

View File

@ -79,7 +79,7 @@ type Session struct {
afterClosures []func(interface{}) afterClosures []func(interface{})
afterProcessors []executedProcessor afterProcessors []executedProcessor
stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) stmtCache map[uint32]*core.Stmt // key: hash.Hash32 of (queryStr, len(queryStr))
txStmtCache map[uint32]*core.Stmt // for tx statement txStmtCache map[uint32]*core.Stmt // for tx statement
lastSQL string lastSQL string
@ -314,7 +314,7 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
// MustLogSQL means record SQL or not and don't follow engine's setting // MustLogSQL means record SQL or not and don't follow engine's setting
func (session *Session) MustLogSQL(logs ...bool) *Session { func (session *Session) MustLogSQL(logs ...bool) *Session {
var showSQL = true showSQL := true
if len(logs) > 0 { if len(logs) > 0 {
showSQL = logs[0] showSQL = logs[0]
} }
@ -396,7 +396,7 @@ func (session *Session) doPrepareTx(sqlStr string) (stmt *core.Stmt, err error)
} }
func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) { func getField(dataStruct *reflect.Value, table *schemas.Table, colName string, idx int) (*schemas.Column, *reflect.Value, error) {
var col = table.GetColumnIdx(colName, idx) col := table.GetColumnIdx(colName, idx)
if col == nil { if col == nil {
return nil, nil, ErrFieldIsNotExist{colName, table.Name} return nil, nil, ErrFieldIsNotExist{colName, table.Name}
} }
@ -420,9 +420,10 @@ type Cell *interface{}
func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType, func (session *Session) rows2Beans(rows *core.Rows, fields []string, types []*sql.ColumnType,
table *schemas.Table, newElemFunc func([]string) reflect.Value, table *schemas.Table, newElemFunc func([]string) reflect.Value,
sliceValueSetFunc func(*reflect.Value, schemas.PK) error) error { sliceValueSetFunc func(*reflect.Value, schemas.PK) error,
) error {
for rows.Next() { for rows.Next() {
var newValue = newElemFunc(fields) newValue := newElemFunc(fields)
bean := newValue.Interface() bean := newValue.Interface()
dataStruct := newValue.Elem() dataStruct := newValue.Elem()
@ -533,8 +534,11 @@ func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv) return nil, fmt.Errorf("unsupported primary key type: %v, %v", tp, vv)
} }
var uint8ZeroValue = reflect.ValueOf(uint8(0))
func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value, func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflect.Value,
scanResult interface{}, table *schemas.Table) error { scanResult interface{}, table *schemas.Table,
) error {
v, ok := scanResult.(*interface{}) v, ok := scanResult.(*interface{})
if ok { if ok {
scanResult = *v scanResult = *v
@ -596,7 +600,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
return nil return nil
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
return setJSON(fieldValue, fieldType, scanResult) return setJSON(fieldValue, fieldType, scanResult)
case reflect.Slice, reflect.Array: case reflect.Slice:
bs, ok := convert.AsBytes(scanResult) bs, ok := convert.AsBytes(scanResult)
if ok && fieldType.Elem().Kind() == reflect.Uint8 { if ok && fieldType.Elem().Kind() == reflect.Uint8 {
if col.SQLType.IsText() { if col.SQLType.IsText() {
@ -607,15 +611,29 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
} }
fieldValue.Set(x.Elem()) fieldValue.Set(x.Elem())
} else { } else {
if fieldValue.Len() > 0 { fieldValue.Set(reflect.ValueOf(bs))
}
return nil
}
case reflect.Array:
bs, ok := convert.AsBytes(scanResult)
if ok && fieldType.Elem().Kind() == reflect.Uint8 {
if col.SQLType.IsText() {
x := reflect.New(fieldType)
err := json.DefaultJSONHandler.Unmarshal(bs, x.Interface())
if err != nil {
return err
}
fieldValue.Set(x.Elem())
} else {
if fieldValue.Len() < vv.Len() {
return fmt.Errorf("Set field %s[Array] failed because of data too long", col.Name)
}
for i := 0; i < fieldValue.Len(); i++ { for i := 0; i < fieldValue.Len(); i++ {
if i < vv.Len() { if i < vv.Len() {
fieldValue.Index(i).Set(vv.Index(i)) fieldValue.Index(i).Set(vv.Index(i))
}
}
} else { } else {
for i := 0; i < vv.Len(); i++ { fieldValue.Index(i).Set(uint8ZeroValue)
fieldValue.Set(reflect.Append(*fieldValue, vv.Index(i)))
} }
} }
} }
@ -659,7 +677,7 @@ func (session *Session) convertBeanField(col *schemas.Column, fieldValue *reflec
if len(table.PrimaryKeys) != 1 { if len(table.PrimaryKeys) != 1 {
return errors.New("unsupported non or composited primary key cascade") return errors.New("unsupported non or composited primary key cascade")
} }
var pk = make(schemas.PK, len(table.PrimaryKeys)) pk := make(schemas.PK, len(table.PrimaryKeys))
pk[0], err = asKind(vv, reflect.TypeOf(scanResult)) pk[0], err = asKind(vv, reflect.TypeOf(scanResult))
if err != nil { if err != nil {
return err return err
@ -694,11 +712,11 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, b
buildAfterProcessors(session, bean) buildAfterProcessors(session, bean)
var tempMap = make(map[string]int) tempMap := make(map[string]int)
var pk schemas.PK var pk schemas.PK
for i, colName := range fields { for i, colName := range fields {
var idx int var idx int
var lKey = strings.ToLower(colName) lKey := strings.ToLower(colName)
var ok bool var ok bool
if idx, ok = tempMap[lKey]; !ok { if idx, ok = tempMap[lKey]; !ok {