diff --git a/convert/conversion.go b/convert/conversion.go index b69e345c..5577e863 100644 --- a/convert/conversion.go +++ b/convert/conversion.go @@ -16,11 +16,21 @@ import ( "time" ) +// ConversionFrom is an inteface to allow retrieve data from database +type ConversionFrom interface { + FromDB([]byte) error +} + +// ConversionTo is an interface to allow store data to database +type ConversionTo interface { + ToDB() ([]byte, error) +} + // Conversion is an interface. A type implements Conversion will according // the custom method to fill into database and retrieve from database. type Conversion interface { - FromDB([]byte) error - ToDB() ([]byte, error) + ConversionFrom + ConversionTo } // ErrNilPtr represents an error diff --git a/internal/statements/statement.go b/internal/statements/statement.go index 7ad735f5..c075ec54 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -644,6 +644,23 @@ func (statement *Statement) convertSQLOrArgs(sqlOrArgs ...interface{}) (string, newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) } else if v, ok := arg.(*time.Time); ok && v != nil { newArgs = append(newArgs, v.In(statement.defaultTimeZone).Format("2006-01-02 15:04:05")) + } else if v, ok := arg.(convert.ConversionTo); ok { + r, err := v.ToDB() + if err != nil { + return "", nil, err + } + if r != nil { + // for nvarchar column on mssql, bytes have to be converted as ucs-2 external of driver + // for binary column, a string will be converted as bytes directly. So we have to + // convert bytes as string + if statement.dialect.URI().DBType == schemas.MSSQL { + newArgs = append(newArgs, string(r)) + } else { + newArgs = append(newArgs, r) + } + } else { + newArgs = append(newArgs, nil) + } } else { newArgs = append(newArgs, arg) } diff --git a/tests/session_raw_test.go b/tests/session_raw_test.go index e6987c41..569d7bed 100644 --- a/tests/session_raw_test.go +++ b/tests/session_raw_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "xorm.io/xorm/convert" + "github.com/stretchr/testify/assert" ) @@ -65,3 +67,48 @@ func TestExecTime(t *testing.T) { assert.True(t, has) assert.EqualValues(t, now.In(testEngine.GetTZLocation()).Format("2006-01-02 15:04:05"), uet.Created.Format("2006-01-02 15:04:05")) } + +type ConversionData struct { + MyData string +} + +var _ convert.Conversion = new(ConversionData) + +func (c ConversionData) ToDB() ([]byte, error) { + return []byte(c.MyData), nil +} + +func (c *ConversionData) FromDB(bs []byte) error { + if bs != nil { + c.MyData = string(bs) + } + return nil +} + +func TestExecCustomTypes(t *testing.T) { + assert.NoError(t, PrepareEngine()) + + type UserinfoExec struct { + Uid int + Name string + Data string + } + + assert.NoError(t, testEngine.Sync2(new(UserinfoExec))) + + res, err := testEngine.Exec("INSERT INTO "+testEngine.TableName("`userinfo_exec`", true)+" (uid, name,data) VALUES (?, ?, ?)", + 1, "user", ConversionData{"data"}) + assert.NoError(t, err) + cnt, err := res.RowsAffected() + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + results, err := testEngine.QueryString("select * from " + testEngine.TableName("userinfo_exec", true)) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + id, err := strconv.Atoi(results[0]["uid"]) + assert.NoError(t, err) + assert.EqualValues(t, 1, id) + assert.Equal(t, "user", results[0]["name"]) + assert.EqualValues(t, "data", results[0]["data"]) +}