From 18362d23402a57bc3089442a0d87b6ed5d35ac3d Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 20 Jul 2017 20:07:54 +0800 Subject: [PATCH] support different text type with JSON type --- dialect_mssql.go | 2 +- dialect_mysql.go | 2 -- dialect_oracle.go | 2 +- dialect_sqlite3.go | 2 +- engine.go | 1 - engine_cond.go | 4 ++-- session.go | 4 ++-- session_convert.go | 2 +- session_delete_test.go | 7 +++++++ session_exist.go | 12 +++++++++-- statement.go | 2 +- tag.go | 17 +++++++++++++++- tag_test.go | 45 ++++++++++++++++++++++++++++++++++++++++++ test_mssql.sh | 2 +- 14 files changed, 88 insertions(+), 16 deletions(-) diff --git a/dialect_mssql.go b/dialect_mssql.go index 6d2291dc..272d4ab2 100644 --- a/dialect_mssql.go +++ b/dialect_mssql.go @@ -243,7 +243,7 @@ func (db *mssql) SqlType(c *core.Column) string { c.Length = 7 case core.MediumInt: res = core.Int - case core.Text, core.MediumText, core.TinyText, core.LongText, core.Json: + case core.Text, core.MediumText, core.TinyText, core.LongText: res = core.Varchar + "(MAX)" case core.Double: res = core.Real diff --git a/dialect_mysql.go b/dialect_mysql.go index 99100b23..c54903b8 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -222,8 +222,6 @@ func (db *mysql) SqlType(c *core.Column) string { case core.Uuid: res = core.Varchar c.Length = 40 - case core.Json: - res = core.Text default: res = t } diff --git a/dialect_oracle.go b/dialect_oracle.go index ac0081b3..17c4adbe 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -519,7 +519,7 @@ func (db *oracle) SqlType(c *core.Column) string { res = "TIMESTAMP WITH TIME ZONE" case core.Float, core.Double, core.Numeric, core.Decimal: res = "NUMBER" - case core.Text, core.MediumText, core.LongText, core.Json: + case core.Text, core.MediumText, core.LongText: res = "CLOB" case core.Char, core.Varchar, core.TinyText: res = "VARCHAR2" diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index a55b1615..f9bc714d 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -165,7 +165,7 @@ func (db *sqlite3) SqlType(c *core.Column) string { case core.TimeStampz: return core.Text case core.Char, core.Varchar, core.NVarchar, core.TinyText, - core.Text, core.MediumText, core.LongText, core.Json: + core.Text, core.MediumText, core.LongText: return core.Text case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt: return core.Integer diff --git a/engine.go b/engine.go index 84e9206d..f0c8f5ad 100644 --- a/engine.go +++ b/engine.go @@ -1023,7 +1023,6 @@ func (engine *Engine) mapType(v reflect.Value) (*core.Table, error) { } table.AddColumn(col) - } // end for if idFieldColName != "" && len(table.PrimaryKeys) == 0 { diff --git a/engine_cond.go b/engine_cond.go index 6c8e3879..5ad5fbfb 100644 --- a/engine_cond.go +++ b/engine_cond.go @@ -34,7 +34,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && (col.SQLType.Name == core.Text || col.SQLType.IsBlob() || col.SQLType.Name == core.TimeStampz) { continue } - if col.SQLType.IsJson() { + if col.IsJSON || col.SQLType.IsJson() { continue } @@ -142,7 +142,7 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{}, continue } } else { - if col.SQLType.IsJson() { + if col.IsJSON || col.SQLType.IsJson() { if col.SQLType.IsText() { bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { diff --git a/session.go b/session.go index 76d7cb28..7410330e 100644 --- a/session.go +++ b/session.go @@ -408,7 +408,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f fieldType := fieldValue.Type() hasAssigned := false - if col.SQLType.IsJson() { + if col.IsJSON || col.SQLType.IsJson() { var bs []byte if rawValueType.Kind() == reflect.String { bs = []byte(vv.String()) @@ -584,7 +584,7 @@ func (session *Session) slice2Bean(scanResults []interface{}, fields []string, f session.Engine.logger.Error("sql.Sanner error:", err.Error()) hasAssigned = false } - } else if col.SQLType.IsJson() { + } else if col.IsJSON || col.SQLType.IsJson() { if rawValueType.Kind() == reflect.String { hasAssigned = true x := reflect.New(fieldType) diff --git a/session_convert.go b/session_convert.go index 931d1dc0..8da415e5 100644 --- a/session_convert.go +++ b/session_convert.go @@ -591,7 +591,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return tf, nil } - if !col.SQLType.IsJson() { + if !(col.IsJSON || col.SQLType.IsJson()) { // !! 增加支持driver.Valuer接口的结构,如sql.NullString if v, ok := fieldValue.Interface().(driver.Valuer); ok { return v.Value() diff --git a/session_delete_test.go b/session_delete_test.go index 27e61321..94b5f43f 100644 --- a/session_delete_test.go +++ b/session_delete_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/go-xorm/core" "github.com/stretchr/testify/assert" ) @@ -21,6 +22,12 @@ func TestDelete(t *testing.T) { assert.NoError(t, testEngine.Sync2(new(UserinfoDelete))) + var err error + if testEngine.dialect.DBType() == core.MSSQL { + _, err = testEngine.Exec("SET IDENTITY_INSERT " + testEngine.TableMapper.Obj2Table("UserinfoDelete") + " ON") + assert.NoError(t, err) + } + user := UserinfoDelete{Uid: 1} cnt, err := testEngine.Insert(&user) assert.NoError(t, err) diff --git a/session_exist.go b/session_exist.go index 6f895c1e..ed450648 100644 --- a/session_exist.go +++ b/session_exist.go @@ -37,10 +37,18 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { return false, err } - sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) + if session.Engine.dialect.DBType() == core.MSSQL { + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s WHERE %s", tableName, condSQL) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s WHERE %s LIMIT 1", tableName, condSQL) + } args = condArgs } else { - sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) + if session.Engine.dialect.DBType() == core.MSSQL { + sqlStr = fmt.Sprintf("SELECT TOP 1 * FROM %s", tableName) + } else { + sqlStr = fmt.Sprintf("SELECT * FROM %s LIMIT 1", tableName) + } args = []interface{}{} } } else { diff --git a/statement.go b/statement.go index 6e360bb3..e022a9b5 100644 --- a/statement.go +++ b/statement.go @@ -380,7 +380,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { val, _ = nulType.Value() } else { - if !col.SQLType.IsJson() { + if !(col.IsJSON || col.SQLType.IsJson()) { engine.autoMapType(fieldValue) if table, ok := engine.Tables[fieldValue.Type()]; ok { if len(table.PrimaryKeys) == 1 { diff --git a/tag.go b/tag.go index e1c821fb..a8ac6df3 100644 --- a/tag.go +++ b/tag.go @@ -55,12 +55,16 @@ var ( "CACHE": CacheTagHandler, "NOCACHE": NoCacheTagHandler, "COMMENT": CommentTagHandler, + "JSON": JSONTagHandler, } ) func init() { for k := range core.SqlTypes { - defaultTagHandlers[k] = SQLTypeTagHandler + // don't overwrite + if _, ok := defaultTagHandlers[k]; !ok { + defaultTagHandlers[k] = SQLTypeTagHandler + } } } @@ -241,8 +245,19 @@ func SQLTypeTagHandler(ctx *tagContext) error { return nil } +// JSONTagHandler handle json tag +func JSONTagHandler(ctx *tagContext) error { + fmt.Println("fdsfafadfs") + ctx.col.IsJSON = true + if len(ctx.params) == 0 { + ctx.col.SQLType = core.SQLType{Name: core.Text} + } + return nil +} + // ExtendsTagHandler describes extends tag handler func ExtendsTagHandler(ctx *tagContext) error { + ctx.ignoreNext = true var fieldValue = ctx.fieldValue switch fieldValue.Kind() { case reflect.Ptr: diff --git a/tag_test.go b/tag_test.go index 4fedd0e7..950b87df 100644 --- a/tag_test.go +++ b/tag_test.go @@ -270,3 +270,48 @@ func TestTagComment(t *testing.T) { assert.EqualValues(t, 1, len(tables[0].Columns())) assert.EqualValues(t, "主键", tables[0].Columns()[0].Comment) } + +func TestTagJSON(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type TestTagJSON1 struct { + Stings []string `xorm:"json"` + } + + assert.NoError(t, testEngine.Sync2(new(TestTagJSON1))) + + table := testEngine.TableInfo(new(TestTagJSON1)) + assert.NotNil(t, table) + assert.EqualValues(t, 1, len(table.Columns())) + assert.True(t, table.Columns()[0].IsJSON) + assert.EqualValues(t, "TEXT", table.Columns()[0].SQLType.Name) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + if testEngine.dialect.DBType() != core.MSSQL { + assert.EqualValues(t, "TEXT", tables[0].Columns()[0].SQLType.Name) + } + assert.NoError(t, testEngine.DropTables(new(TestTagJSON1))) + + type TestTagJSON2 struct { + Stings []string `xorm:"json MediumText"` + } + + assert.NoError(t, testEngine.Sync2(new(TestTagJSON2))) + + table = testEngine.TableInfo(new(TestTagJSON2)) + assert.NotNil(t, table) + assert.EqualValues(t, 1, len(table.Columns())) + assert.True(t, table.Columns()[0].IsJSON) + assert.EqualValues(t, "MEDIUMTEXT", table.Columns()[0].SQLType.Name) + + tables, err = testEngine.DBMetas() + assert.NoError(t, err) + assert.EqualValues(t, 1, len(tables)) + assert.EqualValues(t, 1, len(tables[0].Columns())) + if testEngine.dialect.DBType() == core.MYSQL { + assert.EqualValues(t, "MEDIUMTEXT", tables[0].Columns()[0].SQLType.Name) + } +} diff --git a/test_mssql.sh b/test_mssql.sh index 6f9cf729..203e1aea 100755 --- a/test_mssql.sh +++ b/test_mssql.sh @@ -1 +1 @@ -go test -db=mssql -conn_str="server=192.168.1.58;user id=sa;password=123456;database=xorm_test" \ No newline at end of file +go test -db=mssql -conn_str="server=192.168.1.158;user id=sa;password=123456;database=xorm_test" \ No newline at end of file