From 725e720559db90aad7d04d7bc2b732d02eb90603 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 27 Sep 2019 17:19:27 +0800 Subject: [PATCH] fix default value tags --- dialect_sqlite3.go | 30 ++++++++++++++++++- dialect_sqlite3_test.go | 35 ++++++++++++++++++++++ helpers_test.go | 19 ------------ tag_test.go | 64 +++++++++++++++++++++++++++++++++++++++-- 4 files changed, 125 insertions(+), 23 deletions(-) create mode 100644 dialect_sqlite3_test.go diff --git a/dialect_sqlite3.go b/dialect_sqlite3.go index e4bac00a..d1852e9b 100644 --- a/dialect_sqlite3.go +++ b/dialect_sqlite3.go @@ -270,6 +270,34 @@ func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { return false, nil } +// splitColStr splits a sqlite col strings as fields +func splitColStr(colStr string) []string { + colStr = strings.TrimSpace(colStr) + var results = make([]string, 0, 10) + var lastIdx int + var hasC, hasQuote bool + for i, c := range colStr { + if c == ' ' && !hasQuote { + if hasC { + results = append(results, colStr[lastIdx:i]) + hasC = false + } + } else { + if c == '\'' { + hasQuote = !hasQuote + } + if !hasC { + lastIdx = i + } + hasC = true + if i == len(colStr)-1 { + results = append(results, colStr[lastIdx:i+1]) + } + } + } + return results +} + func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" @@ -315,7 +343,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu continue } - fields := strings.Fields(strings.TrimSpace(colStr)) + fields := splitColStr(colStr) col := new(core.Column) col.Indexes = make(map[string]int) col.Nullable = true diff --git a/dialect_sqlite3_test.go b/dialect_sqlite3_test.go new file mode 100644 index 00000000..a2036159 --- /dev/null +++ b/dialect_sqlite3_test.go @@ -0,0 +1,35 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xorm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSplitColStr(t *testing.T) { + var kases = []struct { + colStr string + fields []string + }{ + { + colStr: "`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL", + fields: []string{ + "`id`", "INTEGER", "PRIMARY", "KEY", "AUTOINCREMENT", "NOT", "NULL", + }, + }, + { + colStr: "`created` DATETIME DEFAULT '2006-01-02 15:04:05' NULL", + fields: []string{ + "`created`", "DATETIME", "DEFAULT", "'2006-01-02 15:04:05'", "NULL", + }, + }, + } + + for _, kase := range kases { + assert.EqualValues(t, kase.fields, splitColStr(kase.colStr)) + } +} diff --git a/helpers_test.go b/helpers_test.go index 7e317126..caf7b9f0 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -10,25 +10,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestSplitTag(t *testing.T) { - var cases = []struct { - tag string - tags []string - }{ - {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, - {"TEXT", []string{"TEXT"}}, - {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, - {"json binary", []string{"json", "binary"}}, - } - - for _, kase := range cases { - tags := splitTag(kase.tag) - if !sliceEq(tags, kase.tags) { - t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) - } - } -} - func TestEraseAny(t *testing.T) { raw := "SELECT * FROM `table`.[table_name]" assert.EqualValues(t, raw, eraseAny(raw)) diff --git a/tag_test.go b/tag_test.go index 74cc5eca..dde4f3d4 100644 --- a/tag_test.go +++ b/tag_test.go @@ -368,19 +368,24 @@ func TestTagDefault4(t *testing.T) { func TestTagDefault5(t *testing.T) { assert.NoError(t, prepareEngine()) - type DefaultStruct4 struct { + type DefaultStruct5 struct { Id int64 Created time.Time `xorm:"default('2006-01-02 15:04:05')"` } - assertSync(t, new(DefaultStruct4)) + assertSync(t, new(DefaultStruct5)) + table := testEngine.TableInfo(new(DefaultStruct5)) + createdCol := table.GetColumn("created") + assert.NotNil(t, createdCol) + assert.EqualValues(t, "'2006-01-02 15:04:05'", createdCol.Default) + assert.False(t, createdCol.DefaultIsEmpty) tables, err := testEngine.DBMetas() assert.NoError(t, err) var defaultVal string var isDefaultExist bool - tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct4") + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct5") for _, table := range tables { if table.Name == tableName { col := table.GetColumn("created") @@ -394,6 +399,40 @@ func TestTagDefault5(t *testing.T) { assert.EqualValues(t, "'2006-01-02 15:04:05'", defaultVal) } +func TestTagDefault6(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type DefaultStruct6 struct { + Id int64 + IsMan bool `xorm:"default(true)"` + } + + assertSync(t, new(DefaultStruct6)) + + tables, err := testEngine.DBMetas() + assert.NoError(t, err) + + var defaultVal string + var isDefaultExist bool + tableName := testEngine.GetColumnMapper().Obj2Table("DefaultStruct6") + for _, table := range tables { + if table.Name == tableName { + col := table.GetColumn("is_man") + assert.NotNil(t, col) + defaultVal = col.Default + isDefaultExist = !col.DefaultIsEmpty + break + } + } + assert.True(t, isDefaultExist) + if defaultVal == "1" { + defaultVal = "true" + } else if defaultVal == "0" { + defaultVal = "false" + } + assert.EqualValues(t, "true", defaultVal) +} + func TestTagsDirection(t *testing.T) { assert.NoError(t, prepareEngine()) @@ -489,3 +528,22 @@ func TestTagTime(t *testing.T) { assert.EqualValues(t, s.Created.UTC().Format("2006-01-02 15:04:05"), strings.Replace(strings.Replace(tm, "T", " ", -1), "Z", "", -1)) } + +func TestSplitTag(t *testing.T) { + var cases = []struct { + tag string + tags []string + }{ + {"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}}, + {"TEXT", []string{"TEXT"}}, + {"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}}, + {"json binary", []string{"json", "binary"}}, + } + + for _, kase := range cases { + tags := splitTag(kase.tag) + if !sliceEq(tags, kase.tags) { + t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags) + } + } +}