Merge branch 'master' into lunny/default
This commit is contained in:
commit
db88ed3016
|
@ -249,11 +249,11 @@ volumes:
|
|||
services:
|
||||
- name: mssql
|
||||
pull: always
|
||||
image: microsoft/mssql-server-linux:latest
|
||||
image: mcr.microsoft.com/mssql/server:latest
|
||||
environment:
|
||||
ACCEPT_EULA: Y
|
||||
SA_PASSWORD: yourStrong(!)Password
|
||||
MSSQL_PID: Developer
|
||||
MSSQL_PID: Standard
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
|
|
|
@ -37,3 +37,4 @@ test.db.sql
|
|||
test.db
|
||||
integrations/*.sql
|
||||
integrations/test_sqlite*
|
||||
cover.out
|
14
.revive.toml
14
.revive.toml
|
@ -8,20 +8,22 @@ warningCode = 1
|
|||
[rule.context-as-argument]
|
||||
[rule.context-keys-type]
|
||||
[rule.dot-imports]
|
||||
[rule.empty-lines]
|
||||
[rule.errorf]
|
||||
[rule.error-return]
|
||||
[rule.error-strings]
|
||||
[rule.error-naming]
|
||||
[rule.exported]
|
||||
[rule.if-return]
|
||||
[rule.increment-decrement]
|
||||
[rule.var-naming]
|
||||
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]
|
||||
[rule.var-declaration]
|
||||
[rule.indent-error-flow]
|
||||
[rule.package-comments]
|
||||
[rule.range]
|
||||
[rule.receiver-naming]
|
||||
[rule.struct-tag]
|
||||
[rule.time-naming]
|
||||
[rule.unexported-return]
|
||||
[rule.indent-error-flow]
|
||||
[rule.errorf]
|
||||
[rule.struct-tag]
|
||||
[rule.unnecessary-stmt]
|
||||
[rule.var-declaration]
|
||||
[rule.var-naming]
|
||||
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]
|
56
convert.go
56
convert.go
|
@ -175,7 +175,10 @@ func convertAssign(dest, src interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
dpv := reflect.ValueOf(dest)
|
||||
return convertAssignV(reflect.ValueOf(dest), src)
|
||||
}
|
||||
|
||||
func convertAssignV(dpv reflect.Value, src interface{}) error {
|
||||
if dpv.Kind() != reflect.Ptr {
|
||||
return errors.New("destination not a pointer")
|
||||
}
|
||||
|
@ -183,9 +186,7 @@ func convertAssign(dest, src interface{}) error {
|
|||
return errNilPtr
|
||||
}
|
||||
|
||||
if !sv.IsValid() {
|
||||
sv = reflect.ValueOf(src)
|
||||
}
|
||||
var sv = reflect.ValueOf(src)
|
||||
|
||||
dv := reflect.Indirect(dpv)
|
||||
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
|
||||
|
@ -244,7 +245,7 @@ func convertAssign(dest, src interface{}) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dpv.Interface())
|
||||
}
|
||||
|
||||
func asKind(vv reflect.Value, tp reflect.Type) (interface{}, error) {
|
||||
|
@ -375,48 +376,3 @@ func str2PK(s string, tp reflect.Type) (interface{}, error) {
|
|||
}
|
||||
return v.Interface(), nil
|
||||
}
|
||||
|
||||
func int64ToIntValue(id int64, tp reflect.Type) reflect.Value {
|
||||
var v interface{}
|
||||
kind := tp.Kind()
|
||||
|
||||
if kind == reflect.Ptr {
|
||||
kind = tp.Elem().Kind()
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Int16:
|
||||
temp := int16(id)
|
||||
v = &temp
|
||||
case reflect.Int32:
|
||||
temp := int32(id)
|
||||
v = &temp
|
||||
case reflect.Int:
|
||||
temp := int(id)
|
||||
v = &temp
|
||||
case reflect.Int64:
|
||||
temp := id
|
||||
v = &temp
|
||||
case reflect.Uint16:
|
||||
temp := uint16(id)
|
||||
v = &temp
|
||||
case reflect.Uint32:
|
||||
temp := uint32(id)
|
||||
v = &temp
|
||||
case reflect.Uint64:
|
||||
temp := uint64(id)
|
||||
v = &temp
|
||||
case reflect.Uint:
|
||||
temp := uint(id)
|
||||
v = &temp
|
||||
}
|
||||
|
||||
if tp.Kind() == reflect.Ptr {
|
||||
return reflect.ValueOf(v).Convert(tp)
|
||||
}
|
||||
return reflect.ValueOf(v).Elem().Convert(tp)
|
||||
}
|
||||
|
||||
func int64ToInt(id int64, tp reflect.Type) interface{} {
|
||||
return int64ToIntValue(id, tp).Interface()
|
||||
}
|
||||
|
|
|
@ -23,13 +23,45 @@ type SeqFilter struct {
|
|||
func convertQuestionMark(sql, prefix string, start int) string {
|
||||
var buf strings.Builder
|
||||
var beginSingleQuote bool
|
||||
var isLineComment bool
|
||||
var isComment bool
|
||||
var isMaybeLineComment bool
|
||||
var isMaybeComment bool
|
||||
var isMaybeCommentEnd bool
|
||||
var index = start
|
||||
for _, c := range sql {
|
||||
if !beginSingleQuote && c == '?' {
|
||||
if !beginSingleQuote && !isLineComment && !isComment && c == '?' {
|
||||
buf.WriteString(fmt.Sprintf("%s%v", prefix, index))
|
||||
index++
|
||||
} else {
|
||||
if c == '\'' {
|
||||
if isMaybeLineComment {
|
||||
if c == '-' {
|
||||
isLineComment = true
|
||||
}
|
||||
isMaybeLineComment = false
|
||||
} else if isMaybeComment {
|
||||
if c == '*' {
|
||||
isComment = true
|
||||
}
|
||||
isMaybeComment = false
|
||||
} else if isMaybeCommentEnd {
|
||||
if c == '/' {
|
||||
isComment = false
|
||||
}
|
||||
isMaybeCommentEnd = false
|
||||
} else if isLineComment {
|
||||
if c == '\n' {
|
||||
isLineComment = false
|
||||
}
|
||||
} else if isComment {
|
||||
if c == '*' {
|
||||
isMaybeCommentEnd = true
|
||||
}
|
||||
} else if !beginSingleQuote && c == '-' {
|
||||
isMaybeLineComment = true
|
||||
} else if !beginSingleQuote && c == '/' {
|
||||
isMaybeComment = true
|
||||
} else if c == '\'' {
|
||||
beginSingleQuote = !beginSingleQuote
|
||||
}
|
||||
buf.WriteRune(c)
|
||||
|
|
|
@ -19,3 +19,60 @@ func TestSeqFilter(t *testing.T) {
|
|||
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqFilterLineComment(t *testing.T) {
|
||||
var kases = map[string]string{
|
||||
`SELECT *
|
||||
FROM TABLE1
|
||||
WHERE foo='bar'
|
||||
AND a=? -- it's a comment
|
||||
AND b=?`: `SELECT *
|
||||
FROM TABLE1
|
||||
WHERE foo='bar'
|
||||
AND a=$1 -- it's a comment
|
||||
AND b=$2`,
|
||||
`SELECT *
|
||||
FROM TABLE1
|
||||
WHERE foo='bar'
|
||||
AND a=? -- it's a comment?
|
||||
AND b=?`: `SELECT *
|
||||
FROM TABLE1
|
||||
WHERE foo='bar'
|
||||
AND a=$1 -- it's a comment?
|
||||
AND b=$2`,
|
||||
`SELECT *
|
||||
FROM TABLE1
|
||||
WHERE a=? -- it's a comment? and that's okay?
|
||||
AND b=?`: `SELECT *
|
||||
FROM TABLE1
|
||||
WHERE a=$1 -- it's a comment? and that's okay?
|
||||
AND b=$2`,
|
||||
}
|
||||
for sql, result := range kases {
|
||||
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeqFilterComment(t *testing.T) {
|
||||
var kases = map[string]string{
|
||||
`SELECT *
|
||||
FROM TABLE1
|
||||
WHERE a=? /* it's a comment */
|
||||
AND b=?`: `SELECT *
|
||||
FROM TABLE1
|
||||
WHERE a=$1 /* it's a comment */
|
||||
AND b=$2`,
|
||||
`SELECT /* it's a comment * ?
|
||||
More comment on the next line! */ *
|
||||
FROM TABLE1
|
||||
WHERE a=? /**/
|
||||
AND b=?`: `SELECT /* it's a comment * ?
|
||||
More comment on the next line! */ *
|
||||
FROM TABLE1
|
||||
WHERE a=$1 /**/
|
||||
AND b=$2`,
|
||||
}
|
||||
for sql, result := range kases {
|
||||
assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1044,12 +1044,13 @@ func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tab
|
|||
|
||||
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) {
|
||||
args := []interface{}{tableName}
|
||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length,
|
||||
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, description,
|
||||
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
|
||||
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
|
||||
FROM pg_attribute f
|
||||
JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid
|
||||
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
|
||||
LEFT JOIN pg_description de ON f.attrelid=de.objoid AND f.attnum=de.objsubid
|
||||
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
|
||||
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
|
||||
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
|
||||
|
@ -1078,9 +1079,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
|
|||
col.Indexes = make(map[string]int)
|
||||
|
||||
var colName, isNullable, dataType string
|
||||
var maxLenStr, colDefault *string
|
||||
var maxLenStr, colDefault, description *string
|
||||
var isPK, isUnique bool
|
||||
err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &isPK, &isUnique)
|
||||
err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &description, &isPK, &isUnique)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -1126,6 +1127,10 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
|
|||
col.DefaultIsEmpty = true
|
||||
}
|
||||
|
||||
if description != nil {
|
||||
col.Comment = *description
|
||||
}
|
||||
|
||||
if isPK {
|
||||
col.IsPrimaryKey = true
|
||||
}
|
||||
|
|
|
@ -652,11 +652,7 @@ func (engine *Engine) dumpTables(tables []*schemas.Table, w io.Writer, tp ...sch
|
|||
return errors.New("unknown column error")
|
||||
}
|
||||
|
||||
fields := strings.Split(col.FieldName, ".")
|
||||
field := dataStruct
|
||||
for _, fieldName := range fields {
|
||||
field = field.FieldByName(fieldName)
|
||||
}
|
||||
field := dataStruct.FieldByIndex(col.FieldIndex)
|
||||
temp += "," + formatColumnValue(dstDialect, field.Interface(), col)
|
||||
}
|
||||
_, err = io.WriteString(w, temp[1:]+");\n")
|
||||
|
|
|
@ -176,6 +176,23 @@ func TestDumpTables(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDumpTables2(t *testing.T) {
|
||||
assert.NoError(t, PrepareEngine())
|
||||
|
||||
type TestDumpTableStruct2 struct {
|
||||
Id int64
|
||||
Created time.Time `xorm:"Default CURRENT_TIMESTAMP"`
|
||||
}
|
||||
|
||||
assertSync(t, new(TestDumpTableStruct2))
|
||||
|
||||
fp := fmt.Sprintf("./dump2-%v-table.sql", testEngine.Dialect().URI().DBType)
|
||||
os.Remove(fp)
|
||||
tb, err := testEngine.TableInfo(new(TestDumpTableStruct2))
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, testEngine.(*xorm.Engine).DumpTablesToFile([]*schemas.Table{tb}, fp))
|
||||
}
|
||||
|
||||
func TestSetSchema(t *testing.T) {
|
||||
assert.NoError(t, PrepareEngine())
|
||||
|
||||
|
@ -209,3 +226,39 @@ func TestDBVersion(t *testing.T) {
|
|||
|
||||
fmt.Println(testEngine.Dialect().URI().DBType, "version is", version)
|
||||
}
|
||||
|
||||
func TestGetColumns(t *testing.T) {
|
||||
if testEngine.Dialect().URI().DBType != schemas.POSTGRES {
|
||||
t.Skip()
|
||||
return
|
||||
}
|
||||
type TestCommentStruct struct {
|
||||
HasComment int
|
||||
NoComment int
|
||||
}
|
||||
|
||||
assertSync(t, new(TestCommentStruct))
|
||||
|
||||
comment := "this is a comment"
|
||||
sql := fmt.Sprintf("comment on column %s.%s is '%s'", testEngine.TableName(new(TestCommentStruct), true), "has_comment", comment)
|
||||
_, err := testEngine.Exec(sql)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tables, err := testEngine.DBMetas()
|
||||
assert.NoError(t, err)
|
||||
tableName := testEngine.GetColumnMapper().Obj2Table("TestCommentStruct")
|
||||
var hasComment, noComment string
|
||||
for _, table := range tables {
|
||||
if table.Name == tableName {
|
||||
col := table.GetColumn("has_comment")
|
||||
assert.NotNil(t, col)
|
||||
hasComment = col.Comment
|
||||
col2 := table.GetColumn("no_comment")
|
||||
assert.NotNil(t, col2)
|
||||
noComment = col2.Comment
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.Equal(t, comment, hasComment)
|
||||
assert.Zero(t, noComment)
|
||||
}
|
||||
|
|
|
@ -32,7 +32,6 @@ func TestInsertOne(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInsertMulti(t *testing.T) {
|
||||
|
||||
assert.NoError(t, PrepareEngine())
|
||||
type TestMulti struct {
|
||||
Id int64 `xorm:"int(11) pk"`
|
||||
|
@ -78,7 +77,6 @@ func insertMultiDatas(step int, datas interface{}) (num int64, err error) {
|
|||
}
|
||||
|
||||
func callbackLooper(datas interface{}, step int, actionFunc func(interface{}) error) (err error) {
|
||||
|
||||
sliceValue := reflect.Indirect(reflect.ValueOf(datas))
|
||||
if sliceValue.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("not slice")
|
||||
|
|
|
@ -472,6 +472,11 @@ func TestUpdateIncrDecr(t *testing.T) {
|
|||
cnt, err = testEngine.ID(col1.Id).Cols(colName).Incr(colName).Update(col1)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
|
||||
testEngine.SetColumnMapper(testEngine.GetColumnMapper())
|
||||
cnt, err = testEngine.Cols(colName).Decr(colName, 2).ID(col1.Id).Update(new(UpdateIncr))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, cnt)
|
||||
}
|
||||
|
||||
type UpdatedUpdate struct {
|
||||
|
|
|
@ -27,6 +27,7 @@ type Expr struct {
|
|||
Arg interface{}
|
||||
}
|
||||
|
||||
// WriteArgs writes args to the writer
|
||||
func (expr *Expr) WriteArgs(w *builder.BytesWriter) error {
|
||||
switch arg := expr.Arg.(type) {
|
||||
case *builder.Builder:
|
||||
|
|
|
@ -343,7 +343,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
|
|||
var args []interface{}
|
||||
var joinStr string
|
||||
var err error
|
||||
var b interface{} = nil
|
||||
var b interface{}
|
||||
if len(bean) > 0 {
|
||||
b = bean[0]
|
||||
beanValue := reflect.ValueOf(bean[0])
|
||||
|
|
|
@ -208,20 +208,18 @@ func (statement *Statement) quote(s string) string {
|
|||
|
||||
// And add Where & and statement
|
||||
func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
|
||||
switch query.(type) {
|
||||
switch qr := query.(type) {
|
||||
case string:
|
||||
cond := builder.Expr(query.(string), args...)
|
||||
cond := builder.Expr(qr, args...)
|
||||
statement.cond = statement.cond.And(cond)
|
||||
case map[string]interface{}:
|
||||
queryMap := query.(map[string]interface{})
|
||||
newMap := make(map[string]interface{})
|
||||
for k, v := range queryMap {
|
||||
newMap[statement.quote(k)] = v
|
||||
cond := make(builder.Eq)
|
||||
for k, v := range qr {
|
||||
cond[statement.quote(k)] = v
|
||||
}
|
||||
statement.cond = statement.cond.And(builder.Eq(newMap))
|
||||
case builder.Cond:
|
||||
cond := query.(builder.Cond)
|
||||
statement.cond = statement.cond.And(cond)
|
||||
case builder.Cond:
|
||||
statement.cond = statement.cond.And(qr)
|
||||
for _, v := range args {
|
||||
if vv, ok := v.(builder.Cond); ok {
|
||||
statement.cond = statement.cond.And(vv)
|
||||
|
@ -236,23 +234,25 @@ func (statement *Statement) And(query interface{}, args ...interface{}) *Stateme
|
|||
|
||||
// Or add Where & Or statement
|
||||
func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
|
||||
switch query.(type) {
|
||||
switch qr := query.(type) {
|
||||
case string:
|
||||
cond := builder.Expr(query.(string), args...)
|
||||
cond := builder.Expr(qr, args...)
|
||||
statement.cond = statement.cond.Or(cond)
|
||||
case map[string]interface{}:
|
||||
cond := builder.Eq(query.(map[string]interface{}))
|
||||
cond := make(builder.Eq)
|
||||
for k, v := range qr {
|
||||
cond[statement.quote(k)] = v
|
||||
}
|
||||
statement.cond = statement.cond.Or(cond)
|
||||
case builder.Cond:
|
||||
cond := query.(builder.Cond)
|
||||
statement.cond = statement.cond.Or(cond)
|
||||
statement.cond = statement.cond.Or(qr)
|
||||
for _, v := range args {
|
||||
if vv, ok := v.(builder.Cond); ok {
|
||||
statement.cond = statement.cond.Or(vv)
|
||||
}
|
||||
}
|
||||
default:
|
||||
// TODO: not support condition type
|
||||
statement.LastError = ErrConditionType
|
||||
}
|
||||
return statement
|
||||
}
|
||||
|
@ -734,6 +734,8 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
|
|||
//engine.logger.Warn(err)
|
||||
}
|
||||
continue
|
||||
} else if fieldValuePtr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if col.IsDeleted && !unscoped { // tag "deleted" is enabled
|
||||
|
|
|
@ -78,7 +78,6 @@ func TestColumnsStringGeneration(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
mapCols := make(map[string]bool)
|
||||
|
@ -101,9 +100,7 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
|||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
for _, col := range cols {
|
||||
|
||||
if _, ok := getFlagForColumn(mapCols, col); !ok {
|
||||
b.Fatal("Unexpected result")
|
||||
}
|
||||
|
@ -112,7 +109,6 @@ func BenchmarkGetFlagForColumnWithICKey_ContainsKey(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
mapCols := make(map[string]bool)
|
||||
|
@ -131,9 +127,7 @@ func BenchmarkGetFlagForColumnWithICKey_EmptyMap(b *testing.B) {
|
|||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
for _, col := range cols {
|
||||
|
||||
if _, ok := getFlagForColumn(mapCols, col); ok {
|
||||
b.Fatal("Unexpected result")
|
||||
}
|
||||
|
|
|
@ -88,6 +88,9 @@ func (statement *Statement) BuildUpdates(tableValue reflect.Value,
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if fieldValuePtr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldValue := *fieldValuePtr
|
||||
fieldType := reflect.TypeOf(fieldValue.Interface())
|
||||
|
|
|
@ -132,7 +132,6 @@ func (s *SimpleLogger) Error(v ...interface{}) {
|
|||
if s.level <= LOG_ERR {
|
||||
s.ERR.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Errorf implement ILogger
|
||||
|
@ -140,7 +139,6 @@ func (s *SimpleLogger) Errorf(format string, v ...interface{}) {
|
|||
if s.level <= LOG_ERR {
|
||||
s.ERR.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Debug implement ILogger
|
||||
|
@ -148,7 +146,6 @@ func (s *SimpleLogger) Debug(v ...interface{}) {
|
|||
if s.level <= LOG_DEBUG {
|
||||
s.DEBUG.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Debugf implement ILogger
|
||||
|
@ -156,7 +153,6 @@ func (s *SimpleLogger) Debugf(format string, v ...interface{}) {
|
|||
if s.level <= LOG_DEBUG {
|
||||
s.DEBUG.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Info implement ILogger
|
||||
|
@ -164,7 +160,6 @@ func (s *SimpleLogger) Info(v ...interface{}) {
|
|||
if s.level <= LOG_INFO {
|
||||
s.INFO.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Infof implement ILogger
|
||||
|
@ -172,7 +167,6 @@ func (s *SimpleLogger) Infof(format string, v ...interface{}) {
|
|||
if s.level <= LOG_INFO {
|
||||
s.INFO.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Warn implement ILogger
|
||||
|
@ -180,7 +174,6 @@ func (s *SimpleLogger) Warn(v ...interface{}) {
|
|||
if s.level <= LOG_WARNING {
|
||||
s.WARN.Output(2, fmt.Sprintln(v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Warnf implement ILogger
|
||||
|
@ -188,7 +181,6 @@ func (s *SimpleLogger) Warnf(format string, v ...interface{}) {
|
|||
if s.level <= LOG_WARNING {
|
||||
s.WARN.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Level implement ILogger
|
||||
|
@ -199,7 +191,6 @@ func (s *SimpleLogger) Level() LogLevel {
|
|||
// SetLevel implement ILogger
|
||||
func (s *SimpleLogger) SetLevel(l LogLevel) {
|
||||
s.level = l
|
||||
return
|
||||
}
|
||||
|
||||
// ShowSQL implement ILogger
|
||||
|
|
|
@ -6,10 +6,8 @@ package schemas
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -25,6 +23,7 @@ type Column struct {
|
|||
Name string
|
||||
TableName string
|
||||
FieldName string // Available only when parsed from a struct
|
||||
FieldIndex []int // Available only when parsed from a struct
|
||||
SQLType SQLType
|
||||
IsJSON bool
|
||||
Length int
|
||||
|
@ -83,41 +82,17 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
|
|||
|
||||
// ValueOfV returns column's filed of struct's value accept reflevt value
|
||||
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
|
||||
var fieldValue reflect.Value
|
||||
fieldPath := strings.Split(col.FieldName, ".")
|
||||
|
||||
if dataStruct.Type().Kind() == reflect.Map {
|
||||
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
|
||||
fieldValue = dataStruct.MapIndex(keyValue)
|
||||
return &fieldValue, nil
|
||||
} else if dataStruct.Type().Kind() == reflect.Interface {
|
||||
structValue := reflect.ValueOf(dataStruct.Interface())
|
||||
dataStruct = &structValue
|
||||
}
|
||||
|
||||
level := len(fieldPath)
|
||||
fieldValue = dataStruct.FieldByName(fieldPath[0])
|
||||
for i := 0; i < level-1; i++ {
|
||||
if !fieldValue.IsValid() {
|
||||
break
|
||||
}
|
||||
if fieldValue.Kind() == reflect.Struct {
|
||||
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
|
||||
} else if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||
var v = *dataStruct
|
||||
for _, i := range col.FieldIndex {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
|
||||
} else {
|
||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
|
||||
v = v.Elem()
|
||||
}
|
||||
v = v.FieldByIndex([]int{i})
|
||||
}
|
||||
|
||||
if !fieldValue.IsValid() {
|
||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
|
||||
}
|
||||
|
||||
return &fieldValue, nil
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// ConvertID converts id content to suitable type according column type
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
package schemas
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -159,24 +158,8 @@ func (table *Table) IDOfV(rv reflect.Value) (PK, error) {
|
|||
for i, col := range table.PKColumns() {
|
||||
var err error
|
||||
|
||||
fieldName := col.FieldName
|
||||
for {
|
||||
parts := strings.SplitN(fieldName, ".", 2)
|
||||
if len(parts) == 1 {
|
||||
break
|
||||
}
|
||||
pkField := v.FieldByIndex(col.FieldIndex)
|
||||
|
||||
v = v.FieldByName(parts[0])
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("Unsupported read value of column %s from field %s", col.Name, col.FieldName)
|
||||
}
|
||||
fieldName = parts[1]
|
||||
}
|
||||
|
||||
pkField := v.FieldByName(fieldName)
|
||||
switch pkField.Kind() {
|
||||
case reflect.String:
|
||||
pk[i], err = col.ConvertID(pkField.String())
|
||||
|
|
|
@ -27,7 +27,6 @@ var testsGetColumn = []struct {
|
|||
var table *Table
|
||||
|
||||
func init() {
|
||||
|
||||
table = NewEmptyTable()
|
||||
|
||||
var name string
|
||||
|
@ -41,7 +40,6 @@ func init() {
|
|||
}
|
||||
|
||||
func TestGetColumn(t *testing.T) {
|
||||
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumn(test.name) == nil {
|
||||
t.Error("Column not found!")
|
||||
|
@ -50,7 +48,6 @@ func TestGetColumn(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetColumnIdx(t *testing.T) {
|
||||
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumnIdx(test.name, test.idx) == nil {
|
||||
t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
|
||||
|
@ -59,7 +56,6 @@ func TestGetColumnIdx(t *testing.T) {
|
|||
}
|
||||
|
||||
func BenchmarkGetColumnWithToLower(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
|
||||
|
@ -71,7 +67,6 @@ func BenchmarkGetColumnWithToLower(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
|
||||
|
@ -89,7 +84,6 @@ func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkGetColumn(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumn(test.name) == nil {
|
||||
|
@ -100,7 +94,6 @@ func BenchmarkGetColumn(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkGetColumnIdx(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumnIdx(test.name, test.idx) == nil {
|
||||
|
|
|
@ -375,6 +375,9 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *s
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fieldValue == nil {
|
||||
return nil, ErrFieldIsNotValid{key, table.Name}
|
||||
}
|
||||
|
||||
if !fieldValue.IsValid() || !fieldValue.CanSet() {
|
||||
return nil, ErrFieldIsNotValid{key, table.Name}
|
||||
|
|
|
@ -35,27 +35,20 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
|
|||
sd, err := strconv.ParseInt(sdata, 10, 64)
|
||||
if err == nil {
|
||||
x = time.Unix(sd, 0)
|
||||
//session.engine.logger.Debugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
} else {
|
||||
//session.engine.logger.Debugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
}
|
||||
} else if len(sdata) > 19 && strings.Contains(sdata, "-") {
|
||||
x, err = time.ParseInLocation(time.RFC3339Nano, sdata, parseLoc)
|
||||
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
session.engine.logger.Debugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.Name, x, sdata)
|
||||
if err != nil {
|
||||
x, err = time.ParseInLocation("2006-01-02 15:04:05.999999999", sdata, parseLoc)
|
||||
//session.engine.logger.Debugf("time(2) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
}
|
||||
if err != nil {
|
||||
x, err = time.ParseInLocation("2006-01-02 15:04:05.9999999 Z07:00", sdata, parseLoc)
|
||||
//session.engine.logger.Debugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
}
|
||||
} else if len(sdata) == 19 && strings.Contains(sdata, "-") {
|
||||
x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, parseLoc)
|
||||
//session.engine.logger.Debugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
|
||||
x, err = time.ParseInLocation("2006-01-02", sdata, parseLoc)
|
||||
//session.engine.logger.Debugf("time(5) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
} else if col.SQLType.Name == schemas.Time {
|
||||
if strings.Contains(sdata, " ") {
|
||||
ssd := strings.Split(sdata, " ")
|
||||
|
@ -69,7 +62,6 @@ func (session *Session) str2Time(col *schemas.Column, data string) (outTime time
|
|||
|
||||
st := fmt.Sprintf("2006-01-02 %v", sdata)
|
||||
x, err = time.ParseInLocation("2006-01-02 15:04:05", st, parseLoc)
|
||||
//session.engine.logger.Debugf("time(6) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata)
|
||||
} else {
|
||||
outErr = fmt.Errorf("unsupported time format %v", sdata)
|
||||
return
|
||||
|
|
|
@ -374,9 +374,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
return 1, nil
|
||||
}
|
||||
|
||||
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
|
||||
|
||||
return 1, nil
|
||||
return 1, convertAssignV(aiValue.Addr(), id)
|
||||
} else if len(table.AutoIncrement) > 0 && (session.engine.dialect.URI().DBType == schemas.POSTGRES ||
|
||||
session.engine.dialect.URI().DBType == schemas.MSSQL) {
|
||||
res, err := session.queryBytes(sqlStr, args...)
|
||||
|
@ -416,9 +414,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
return 1, nil
|
||||
}
|
||||
|
||||
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
|
||||
|
||||
return 1, nil
|
||||
return 1, convertAssignV(aiValue.Addr(), id)
|
||||
}
|
||||
|
||||
res, err := session.exec(sqlStr, args...)
|
||||
|
@ -458,7 +454,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
|
|||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
aiValue.Set(int64ToIntValue(id, aiValue.Type()))
|
||||
if err := convertAssignV(aiValue.Addr(), id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
|
|
@ -280,15 +280,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
k = ct.Elem().Kind()
|
||||
}
|
||||
if k == reflect.Struct {
|
||||
var refTable = session.statement.RefTable
|
||||
if refTable == nil {
|
||||
refTable, err = session.engine.TableInfo(condiBean[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
condTable, err := session.engine.TableInfo(condiBean[0])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var err error
|
||||
autoCond, err = session.statement.BuildConds(refTable, condiBean[0], true, true, false, true, false)
|
||||
|
||||
autoCond, err = session.statement.BuildConds(condTable, condiBean[0], true, true, false, true, false)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -457,7 +454,6 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
// FIXME: if bean is a map type, it will panic because map cannot be as map key
|
||||
session.afterUpdateBeans[bean] = &afterClosures
|
||||
}
|
||||
|
||||
} else {
|
||||
if _, ok := interface{}(bean).(AfterUpdateProcessor); ok {
|
||||
session.afterUpdateBeans[bean] = nil
|
||||
|
|
330
tags/parser.go
330
tags/parser.go
|
@ -7,7 +7,6 @@ package tags
|
|||
import (
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -23,7 +22,7 @@ import (
|
|||
|
||||
var (
|
||||
// ErrUnsupportedType represents an unsupported type error
|
||||
ErrUnsupportedType = errors.New("Unsupported type")
|
||||
ErrUnsupportedType = errors.New("unsupported type")
|
||||
)
|
||||
|
||||
// Parser represents a parser for xorm tag
|
||||
|
@ -125,6 +124,147 @@ func addIndex(indexName string, table *schemas.Table, col *schemas.Column, index
|
|||
}
|
||||
}
|
||||
|
||||
var ErrIgnoreField = errors.New("field will be ignored")
|
||||
|
||||
func (parser *Parser) parseFieldWithNoTag(fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
||||
var sqlType schemas.SQLType
|
||||
if fieldValue.CanAddr() {
|
||||
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
|
||||
sqlType = schemas.SQLType{Name: schemas.Text}
|
||||
}
|
||||
}
|
||||
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||
sqlType = schemas.SQLType{Name: schemas.Text}
|
||||
} else {
|
||||
sqlType = schemas.Type2SQLType(field.Type)
|
||||
}
|
||||
col := schemas.NewColumn(parser.columnMapper.Obj2Table(field.Name),
|
||||
field.Name, sqlType, sqlType.DefaultLength,
|
||||
sqlType.DefaultLength2, true)
|
||||
col.FieldIndex = []int{fieldIndex}
|
||||
|
||||
if field.Type.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
|
||||
col.IsAutoIncrement = true
|
||||
col.IsPrimaryKey = true
|
||||
col.Nullable = false
|
||||
}
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func (parser *Parser) parseFieldWithTags(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value, tags []tag) (*schemas.Column, error) {
|
||||
var col = &schemas.Column{
|
||||
FieldName: field.Name,
|
||||
FieldIndex: []int{fieldIndex},
|
||||
Nullable: true,
|
||||
IsPrimaryKey: false,
|
||||
IsAutoIncrement: false,
|
||||
MapType: schemas.TWOSIDES,
|
||||
Indexes: make(map[string]int),
|
||||
DefaultIsEmpty: true,
|
||||
}
|
||||
|
||||
var ctx = Context{
|
||||
table: table,
|
||||
col: col,
|
||||
fieldValue: fieldValue,
|
||||
indexNames: make(map[string]int),
|
||||
parser: parser,
|
||||
}
|
||||
|
||||
for j, tag := range tags {
|
||||
if ctx.ignoreNext {
|
||||
ctx.ignoreNext = false
|
||||
continue
|
||||
}
|
||||
|
||||
ctx.tag = tag
|
||||
ctx.tagUname = strings.ToUpper(tag.name)
|
||||
|
||||
if j > 0 {
|
||||
ctx.preTag = strings.ToUpper(tags[j-1].name)
|
||||
}
|
||||
if j < len(tags)-1 {
|
||||
ctx.nextTag = tags[j+1].name
|
||||
} else {
|
||||
ctx.nextTag = ""
|
||||
}
|
||||
|
||||
if h, ok := parser.handlers[ctx.tagUname]; ok {
|
||||
if err := h(&ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(ctx.tag.name, "'") && strings.HasSuffix(ctx.tag.name, "'") {
|
||||
col.Name = ctx.tag.name[1 : len(ctx.tag.name)-1]
|
||||
} else {
|
||||
col.Name = ctx.tag.name
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.hasCacheTag {
|
||||
if parser.cacherMgr.GetDefaultCacher() != nil {
|
||||
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
|
||||
} else {
|
||||
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
|
||||
}
|
||||
}
|
||||
if ctx.hasNoCacheTag {
|
||||
parser.cacherMgr.SetCacher(table.Name, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if col.SQLType.Name == "" {
|
||||
col.SQLType = schemas.Type2SQLType(field.Type)
|
||||
}
|
||||
parser.dialect.SQLType(col)
|
||||
if col.Length == 0 {
|
||||
col.Length = col.SQLType.DefaultLength
|
||||
}
|
||||
if col.Length2 == 0 {
|
||||
col.Length2 = col.SQLType.DefaultLength2
|
||||
}
|
||||
if col.Name == "" {
|
||||
col.Name = parser.columnMapper.Obj2Table(field.Name)
|
||||
}
|
||||
|
||||
if ctx.isUnique {
|
||||
ctx.indexNames[col.Name] = schemas.UniqueType
|
||||
} else if ctx.isIndex {
|
||||
ctx.indexNames[col.Name] = schemas.IndexType
|
||||
}
|
||||
|
||||
for indexName, indexType := range ctx.indexNames {
|
||||
addIndex(indexName, table, col, indexType)
|
||||
}
|
||||
|
||||
return col, nil
|
||||
}
|
||||
|
||||
func (parser *Parser) parseField(table *schemas.Table, fieldIndex int, field reflect.StructField, fieldValue reflect.Value) (*schemas.Column, error) {
|
||||
var (
|
||||
tag = field.Tag
|
||||
ormTagStr = strings.TrimSpace(tag.Get(parser.identifier))
|
||||
)
|
||||
if ormTagStr == "-" {
|
||||
return nil, ErrIgnoreField
|
||||
}
|
||||
if ormTagStr == "" {
|
||||
return parser.parseFieldWithNoTag(fieldIndex, field, fieldValue)
|
||||
}
|
||||
tags, err := splitTag(ormTagStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parser.parseFieldWithTags(table, fieldIndex, field, fieldValue, tags)
|
||||
}
|
||||
|
||||
func isNotTitle(n string) bool {
|
||||
for _, c := range n {
|
||||
return unicode.IsLower(c)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Parse parses a struct as a table information
|
||||
func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
||||
t := v.Type()
|
||||
|
@ -140,193 +280,21 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
|
|||
table.Type = t
|
||||
table.Name = names.GetTableName(parser.tableMapper, v)
|
||||
|
||||
var idFieldColName string
|
||||
var hasCacheTag, hasNoCacheTag bool
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
var isUnexportField bool
|
||||
for _, c := range t.Field(i).Name {
|
||||
if unicode.IsLower(c) {
|
||||
isUnexportField = true
|
||||
}
|
||||
break
|
||||
}
|
||||
if isUnexportField {
|
||||
var field = t.Field(i)
|
||||
if isNotTitle(field.Name) {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := t.Field(i).Tag
|
||||
ormTagStr := tag.Get(parser.identifier)
|
||||
var col *schemas.Column
|
||||
fieldValue := v.Field(i)
|
||||
fieldType := fieldValue.Type()
|
||||
|
||||
if ormTagStr != "" {
|
||||
col = &schemas.Column{
|
||||
FieldName: t.Field(i).Name,
|
||||
Nullable: true,
|
||||
IsPrimaryKey: false,
|
||||
IsAutoIncrement: false,
|
||||
MapType: schemas.TWOSIDES,
|
||||
Indexes: make(map[string]int),
|
||||
DefaultIsEmpty: true,
|
||||
}
|
||||
tags := splitTag(ormTagStr)
|
||||
|
||||
if len(tags) > 0 {
|
||||
if tags[0] == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
var ctx = Context{
|
||||
table: table,
|
||||
col: col,
|
||||
fieldValue: fieldValue,
|
||||
indexNames: make(map[string]int),
|
||||
parser: parser,
|
||||
}
|
||||
|
||||
if strings.HasPrefix(strings.ToUpper(tags[0]), "EXTENDS") {
|
||||
pStart := strings.Index(tags[0], "(")
|
||||
if pStart > -1 && strings.HasSuffix(tags[0], ")") {
|
||||
var tagPrefix = strings.TrimFunc(tags[0][pStart+1:len(tags[0])-1], func(r rune) bool {
|
||||
return r == '\'' || r == '"'
|
||||
})
|
||||
|
||||
ctx.params = []string{tagPrefix}
|
||||
}
|
||||
|
||||
if err := ExtendsTagHandler(&ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for j, key := range tags {
|
||||
if ctx.ignoreNext {
|
||||
ctx.ignoreNext = false
|
||||
continue
|
||||
}
|
||||
|
||||
k := strings.ToUpper(key)
|
||||
ctx.tagName = k
|
||||
ctx.params = []string{}
|
||||
|
||||
pStart := strings.Index(k, "(")
|
||||
if pStart == 0 {
|
||||
return nil, errors.New("( could not be the first character")
|
||||
}
|
||||
if pStart > -1 {
|
||||
if !strings.HasSuffix(k, ")") {
|
||||
return nil, fmt.Errorf("field %s tag %s cannot match ) character", col.FieldName, key)
|
||||
}
|
||||
|
||||
ctx.tagName = k[:pStart]
|
||||
ctx.params = strings.Split(key[pStart+1:len(k)-1], ",")
|
||||
}
|
||||
|
||||
if j > 0 {
|
||||
ctx.preTag = strings.ToUpper(tags[j-1])
|
||||
}
|
||||
if j < len(tags)-1 {
|
||||
ctx.nextTag = tags[j+1]
|
||||
} else {
|
||||
ctx.nextTag = ""
|
||||
}
|
||||
|
||||
if h, ok := parser.handlers[ctx.tagName]; ok {
|
||||
if err := h(&ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
|
||||
col.Name = key[1 : len(key)-1]
|
||||
} else {
|
||||
col.Name = key
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.hasCacheTag {
|
||||
hasCacheTag = true
|
||||
}
|
||||
if ctx.hasNoCacheTag {
|
||||
hasNoCacheTag = true
|
||||
}
|
||||
}
|
||||
|
||||
if col.SQLType.Name == "" {
|
||||
col.SQLType = schemas.Type2SQLType(fieldType)
|
||||
}
|
||||
parser.dialect.SQLType(col)
|
||||
if col.Length == 0 {
|
||||
col.Length = col.SQLType.DefaultLength
|
||||
}
|
||||
if col.Length2 == 0 {
|
||||
col.Length2 = col.SQLType.DefaultLength2
|
||||
}
|
||||
if col.Name == "" {
|
||||
col.Name = parser.columnMapper.Obj2Table(t.Field(i).Name)
|
||||
}
|
||||
|
||||
if ctx.isUnique {
|
||||
ctx.indexNames[col.Name] = schemas.UniqueType
|
||||
} else if ctx.isIndex {
|
||||
ctx.indexNames[col.Name] = schemas.IndexType
|
||||
}
|
||||
|
||||
for indexName, indexType := range ctx.indexNames {
|
||||
addIndex(indexName, table, col, indexType)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var sqlType schemas.SQLType
|
||||
if fieldValue.CanAddr() {
|
||||
if _, ok := fieldValue.Addr().Interface().(convert.Conversion); ok {
|
||||
sqlType = schemas.SQLType{Name: schemas.Text}
|
||||
}
|
||||
}
|
||||
if _, ok := fieldValue.Interface().(convert.Conversion); ok {
|
||||
sqlType = schemas.SQLType{Name: schemas.Text}
|
||||
} else {
|
||||
sqlType = schemas.Type2SQLType(fieldType)
|
||||
}
|
||||
col = schemas.NewColumn(parser.columnMapper.Obj2Table(t.Field(i).Name),
|
||||
t.Field(i).Name, sqlType, sqlType.DefaultLength,
|
||||
sqlType.DefaultLength2, true)
|
||||
|
||||
if fieldType.Kind() == reflect.Int64 && (strings.ToUpper(col.FieldName) == "ID" || strings.HasSuffix(strings.ToUpper(col.FieldName), ".ID")) {
|
||||
idFieldColName = col.Name
|
||||
}
|
||||
}
|
||||
if col.IsAutoIncrement {
|
||||
col.Nullable = false
|
||||
col, err := parser.parseField(table, i, field, v.Field(i))
|
||||
if err == ErrIgnoreField {
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
table.AddColumn(col)
|
||||
} // end for
|
||||
|
||||
if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
|
||||
col := table.GetColumn(idFieldColName)
|
||||
col.IsPrimaryKey = true
|
||||
col.IsAutoIncrement = true
|
||||
col.Nullable = false
|
||||
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
|
||||
table.AutoIncrement = col.Name
|
||||
}
|
||||
|
||||
if hasCacheTag {
|
||||
if parser.cacherMgr.GetDefaultCacher() != nil { // !nash! use engine's cacher if provided
|
||||
//engine.logger.Info("enable cache on table:", table.Name)
|
||||
parser.cacherMgr.SetCacher(table.Name, parser.cacherMgr.GetDefaultCacher())
|
||||
} else {
|
||||
//engine.logger.Info("enable LRU cache on table:", table.Name)
|
||||
parser.cacherMgr.SetCacher(table.Name, caches.NewLRUCacher2(caches.NewMemoryStore(), time.Hour, 10000))
|
||||
}
|
||||
}
|
||||
if hasNoCacheTag {
|
||||
//engine.logger.Info("disable cache on table:", table.Name)
|
||||
parser.cacherMgr.SetCacher(table.Name, nil)
|
||||
}
|
||||
|
||||
return table, nil
|
||||
}
|
||||
|
|
|
@ -6,12 +6,16 @@ package tags
|
|||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"xorm.io/xorm/caches"
|
||||
"xorm.io/xorm/dialects"
|
||||
"xorm.io/xorm/names"
|
||||
"xorm.io/xorm/schemas"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type ParseTableName1 struct{}
|
||||
|
@ -80,7 +84,7 @@ func TestParseWithOtherIdentifier(t *testing.T) {
|
|||
parser := NewParser(
|
||||
"xorm",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.GonicMapper{},
|
||||
names.SameMapper{},
|
||||
names.SnakeMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
@ -88,13 +92,461 @@ func TestParseWithOtherIdentifier(t *testing.T) {
|
|||
type StructWithDBTag struct {
|
||||
FieldFoo string `db:"foo"`
|
||||
}
|
||||
|
||||
parser.SetIdentifier("db")
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithDBTag)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_db_tag", table.Name)
|
||||
assert.EqualValues(t, "StructWithDBTag", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
|
||||
for _, col := range table.Columns() {
|
||||
assert.EqualValues(t, "foo", col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWithIgnore(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SameMapper{},
|
||||
names.SnakeMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithIgnoreTag struct {
|
||||
FieldFoo string `db:"-"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithIgnoreTag)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "StructWithIgnoreTag", table.Name)
|
||||
assert.EqualValues(t, 0, len(table.Columns()))
|
||||
}
|
||||
|
||||
func TestParseWithAutoincrement(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithAutoIncrement struct {
|
||||
ID int64
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_auto_increment", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
assert.EqualValues(t, "id", table.Columns()[0].Name)
|
||||
assert.True(t, table.Columns()[0].IsAutoIncrement)
|
||||
assert.True(t, table.Columns()[0].IsPrimaryKey)
|
||||
}
|
||||
|
||||
func TestParseWithAutoincrement2(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithAutoIncrement2 struct {
|
||||
ID int64 `db:"pk autoincr"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithAutoIncrement2)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_auto_increment2", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
assert.EqualValues(t, "id", table.Columns()[0].Name)
|
||||
assert.True(t, table.Columns()[0].IsAutoIncrement)
|
||||
assert.True(t, table.Columns()[0].IsPrimaryKey)
|
||||
assert.False(t, table.Columns()[0].Nullable)
|
||||
}
|
||||
|
||||
func TestParseWithNullable(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithNullable struct {
|
||||
Name string `db:"notnull"`
|
||||
FullName string `db:"null comment('column comment,字段注释')"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithNullable)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_nullable", table.Name)
|
||||
assert.EqualValues(t, 2, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.EqualValues(t, "full_name", table.Columns()[1].Name)
|
||||
assert.False(t, table.Columns()[0].Nullable)
|
||||
assert.True(t, table.Columns()[1].Nullable)
|
||||
assert.EqualValues(t, "column comment,字段注释", table.Columns()[1].Comment)
|
||||
}
|
||||
|
||||
func TestParseWithTimes(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithTimes struct {
|
||||
Name string `db:"notnull"`
|
||||
CreatedAt time.Time `db:"created"`
|
||||
UpdatedAt time.Time `db:"updated"`
|
||||
DeletedAt time.Time `db:"deleted"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithTimes)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_times", table.Name)
|
||||
assert.EqualValues(t, 4, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
|
||||
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
|
||||
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
|
||||
assert.False(t, table.Columns()[0].Nullable)
|
||||
assert.True(t, table.Columns()[1].Nullable)
|
||||
assert.True(t, table.Columns()[1].IsCreated)
|
||||
assert.True(t, table.Columns()[2].Nullable)
|
||||
assert.True(t, table.Columns()[2].IsUpdated)
|
||||
assert.True(t, table.Columns()[3].Nullable)
|
||||
assert.True(t, table.Columns()[3].IsDeleted)
|
||||
}
|
||||
|
||||
func TestParseWithExtends(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithEmbed struct {
|
||||
Name string
|
||||
CreatedAt time.Time `db:"created"`
|
||||
UpdatedAt time.Time `db:"updated"`
|
||||
DeletedAt time.Time `db:"deleted"`
|
||||
}
|
||||
|
||||
type StructWithExtends struct {
|
||||
SW StructWithEmbed `db:"extends"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithExtends)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_extends", table.Name)
|
||||
assert.EqualValues(t, 4, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.EqualValues(t, "created_at", table.Columns()[1].Name)
|
||||
assert.EqualValues(t, "updated_at", table.Columns()[2].Name)
|
||||
assert.EqualValues(t, "deleted_at", table.Columns()[3].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
assert.True(t, table.Columns()[1].Nullable)
|
||||
assert.True(t, table.Columns()[1].IsCreated)
|
||||
assert.True(t, table.Columns()[2].Nullable)
|
||||
assert.True(t, table.Columns()[2].IsUpdated)
|
||||
assert.True(t, table.Columns()[3].Nullable)
|
||||
assert.True(t, table.Columns()[3].IsDeleted)
|
||||
}
|
||||
|
||||
func TestParseWithCache(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithCache struct {
|
||||
Name string `db:"cache"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithCache)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_cache", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
cacher := parser.cacherMgr.GetCacher(table.Name)
|
||||
assert.NotNil(t, cacher)
|
||||
}
|
||||
|
||||
func TestParseWithNoCache(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithNoCache struct {
|
||||
Name string `db:"nocache"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithNoCache)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_no_cache", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
cacher := parser.cacherMgr.GetCacher(table.Name)
|
||||
assert.Nil(t, cacher)
|
||||
}
|
||||
|
||||
func TestParseWithEnum(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithEnum struct {
|
||||
Name string `db:"enum('alice', 'bob')"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithEnum)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_enum", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
assert.EqualValues(t, schemas.Enum, strings.ToUpper(table.Columns()[0].SQLType.Name))
|
||||
assert.EqualValues(t, map[string]int{
|
||||
"alice": 0,
|
||||
"bob": 1,
|
||||
}, table.Columns()[0].EnumOptions)
|
||||
}
|
||||
|
||||
func TestParseWithSet(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithSet struct {
|
||||
Name string `db:"set('alice', 'bob')"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithSet)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_set", table.Name)
|
||||
assert.EqualValues(t, 1, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
assert.EqualValues(t, schemas.Set, strings.ToUpper(table.Columns()[0].SQLType.Name))
|
||||
assert.EqualValues(t, map[string]int{
|
||||
"alice": 0,
|
||||
"bob": 1,
|
||||
}, table.Columns()[0].SetOptions)
|
||||
}
|
||||
|
||||
func TestParseWithIndex(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithIndex struct {
|
||||
Name string `db:"index"`
|
||||
Name2 string `db:"index(s)"`
|
||||
Name3 string `db:"unique"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithIndex)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_index", table.Name)
|
||||
assert.EqualValues(t, 3, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.EqualValues(t, "name2", table.Columns()[1].Name)
|
||||
assert.EqualValues(t, "name3", table.Columns()[2].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
assert.True(t, table.Columns()[1].Nullable)
|
||||
assert.True(t, table.Columns()[2].Nullable)
|
||||
assert.EqualValues(t, 1, len(table.Columns()[0].Indexes))
|
||||
assert.EqualValues(t, 1, len(table.Columns()[1].Indexes))
|
||||
assert.EqualValues(t, 1, len(table.Columns()[2].Indexes))
|
||||
}
|
||||
|
||||
func TestParseWithVersion(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
||||
names.SnakeMapper{},
|
||||
names.GonicMapper{},
|
||||
caches.NewManager(),
|
||||
)
|
||||
|
||||
type StructWithVersion struct {
|
||||
Name string
|
||||
Version int `db:"version"`
|
||||
}
|
||||
|
||||
table, err := parser.Parse(reflect.ValueOf(new(StructWithVersion)))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, "struct_with_version", table.Name)
|
||||
assert.EqualValues(t, 2, len(table.Columns()))
|
||||
assert.EqualValues(t, "name", table.Columns()[0].Name)
|
||||
assert.EqualValues(t, "version", table.Columns()[1].Name)
|
||||
assert.True(t, table.Columns()[0].Nullable)
|
||||
assert.True(t, table.Columns()[1].Nullable)
|
||||
assert.True(t, table.Columns()[1].IsVersion)
|
||||
}
|
||||
|
||||
func TestParseWithLocale(t *testing.T) {
|
||||
parser := NewParser(
|
||||
"db",
|
||||
dialects.QueryDialect("mysql"),
|
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)

	type StructWithLocale struct {
		UTCLocale   time.Time `db:"utc"`
		LocalLocale time.Time `db:"local"`
	}

	table, err := parser.Parse(reflect.ValueOf(new(StructWithLocale)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_locale", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "utc_locale", table.Columns()[0].Name)
	assert.EqualValues(t, "local_locale", table.Columns()[1].Name)
	assert.EqualValues(t, time.UTC, table.Columns()[0].TimeZone)
	assert.EqualValues(t, time.Local, table.Columns()[1].TimeZone)
}

func TestParseWithDefault(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.SnakeMapper{},
		names.GonicMapper{},
		caches.NewManager(),
	)

	type StructWithDefault struct {
		Default1 time.Time `db:"default '1970-01-01 00:00:00'"`
		Default2 time.Time `db:"default(CURRENT_TIMESTAMP)"`
	}

	table, err := parser.Parse(reflect.ValueOf(new(StructWithDefault)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_default", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "default1", table.Columns()[0].Name)
	assert.EqualValues(t, "default2", table.Columns()[1].Name)
	assert.EqualValues(t, "'1970-01-01 00:00:00'", table.Columns()[0].Default)
	assert.EqualValues(t, "CURRENT_TIMESTAMP", table.Columns()[1].Default)
}

func TestParseWithOnlyToDB(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.GonicMapper{
			"DB": true,
		},
		names.SnakeMapper{},
		caches.NewManager(),
	)

	type StructWithOnlyToDB struct {
		Default1 time.Time `db:"->"`
		Default2 time.Time `db:"<-"`
	}

	table, err := parser.Parse(reflect.ValueOf(new(StructWithOnlyToDB)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_only_to_db", table.Name)
	assert.EqualValues(t, 2, len(table.Columns()))
	assert.EqualValues(t, "default1", table.Columns()[0].Name)
	assert.EqualValues(t, "default2", table.Columns()[1].Name)
	assert.EqualValues(t, schemas.ONLYTODB, table.Columns()[0].MapType)
	assert.EqualValues(t, schemas.ONLYFROMDB, table.Columns()[1].MapType)
}

func TestParseWithJSON(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.GonicMapper{
			"JSON": true,
		},
		names.SnakeMapper{},
		caches.NewManager(),
	)

	type StructWithJSON struct {
		Default1 []string `db:"json"`
	}

	table, err := parser.Parse(reflect.ValueOf(new(StructWithJSON)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_json", table.Name)
	assert.EqualValues(t, 1, len(table.Columns()))
	assert.EqualValues(t, "default1", table.Columns()[0].Name)
	assert.True(t, table.Columns()[0].IsJSON)
}

func TestParseWithSQLType(t *testing.T) {
	parser := NewParser(
		"db",
		dialects.QueryDialect("mysql"),
		names.GonicMapper{
			"SQL": true,
		},
		names.GonicMapper{
			"UUID": true,
		},
		caches.NewManager(),
	)

	type StructWithSQLType struct {
		Col1     string    `db:"varchar(32)"`
		Col2     string    `db:"char(32)"`
		Int      int64     `db:"bigint"`
		DateTime time.Time `db:"datetime"`
		UUID     string    `db:"uuid"`
	}

	table, err := parser.Parse(reflect.ValueOf(new(StructWithSQLType)))
	assert.NoError(t, err)
	assert.EqualValues(t, "struct_with_sql_type", table.Name)
	assert.EqualValues(t, 5, len(table.Columns()))
	assert.EqualValues(t, "col1", table.Columns()[0].Name)
	assert.EqualValues(t, "col2", table.Columns()[1].Name)
	assert.EqualValues(t, "int", table.Columns()[2].Name)
	assert.EqualValues(t, "date_time", table.Columns()[3].Name)
	assert.EqualValues(t, "uuid", table.Columns()[4].Name)

	assert.EqualValues(t, "VARCHAR", table.Columns()[0].SQLType.Name)
	assert.EqualValues(t, "CHAR", table.Columns()[1].SQLType.Name)
	assert.EqualValues(t, "BIGINT", table.Columns()[2].SQLType.Name)
	assert.EqualValues(t, "DATETIME", table.Columns()[3].SQLType.Name)
	assert.EqualValues(t, "UUID", table.Columns()[4].SQLType.Name)
}
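The tests above all follow the same pattern: build a Parser with a tag identifier and name mappers, then assert on the table and column metadata extracted from a tagged struct. As a rough illustration only (not part of this change), a model combining the tag forms exercised above could look like the sketch below; the Article type and its fields are hypothetical, and the column-name comments assume the snake-case mapping used in the tests.

package main

import "time"

// Article is a hypothetical model, shown only to illustrate the tag forms the
// tests above exercise; column-name comments assume snake-case mapping.
type Article struct {
	Title     string    `db:"varchar(64) notnull"`                 // "title": VARCHAR(64), not null
	Tags      []string  `db:"json"`                                // "tags": marshalled as JSON
	Price     float64   `db:"numeric(10, 2)"`                      // "price": NUMERIC with Length 10, Length2 2
	CreatedAt time.Time `db:"datetime default(CURRENT_TIMESTAMP)"` // "created_at": DATETIME with a default
	ViewCount int64     `db:"->"`                                  // "view_count": written to the database only (ONLYTODB)
}

func main() {}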
148 tags/tag.go
@ -14,30 +14,74 @@ import (
	"xorm.io/xorm/schemas"
)

func splitTag(tag string) (tags []string) {
	tag = strings.TrimSpace(tag)
	var hasQuote = false
	var lastIdx = 0
	for i, t := range tag {
		if t == '\'' {
			hasQuote = !hasQuote
		} else if t == ' ' {
			if lastIdx < i && !hasQuote {
				tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
				lastIdx = i + 1
type tag struct {
	name   string
	params []string
}

func splitTag(tagStr string) ([]tag, error) {
	tagStr = strings.TrimSpace(tagStr)
	var (
		inQuote    bool
		inBigQuote bool
		lastIdx    int
		curTag     tag
		paramStart int
		tags       []tag
	)
	for i, t := range tagStr {
		switch t {
		case '\'':
			inQuote = !inQuote
		case ' ':
			if !inQuote && !inBigQuote {
				if lastIdx < i {
					if curTag.name == "" {
						curTag.name = tagStr[lastIdx:i]
					}
					tags = append(tags, curTag)
					lastIdx = i + 1
					curTag = tag{}
				} else if lastIdx == i {
					lastIdx = i + 1
				}
			} else if inBigQuote && !inQuote {
				paramStart = i + 1
			}
		case ',':
			if !inQuote && !inBigQuote {
				return nil, fmt.Errorf("comma[%d] of %s should be in quote or big quote", i, tagStr)
			}
			if !inQuote && inBigQuote {
				curTag.params = append(curTag.params, strings.TrimSpace(tagStr[paramStart:i]))
				paramStart = i + 1
			}
		case '(':
			inBigQuote = true
			if !inQuote {
				curTag.name = tagStr[lastIdx:i]
				paramStart = i + 1
			}
		case ')':
			inBigQuote = false
			if !inQuote {
				curTag.params = append(curTag.params, tagStr[paramStart:i])
			}
		}
	}
	if lastIdx < len(tag) {
		tags = append(tags, strings.TrimSpace(tag[lastIdx:]))
	if lastIdx < len(tagStr) {
		if curTag.name == "" {
			curTag.name = tagStr[lastIdx:]
		}
		tags = append(tags, curTag)
	}
	return
	return tags, nil
}
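For readers skimming the diff: the rewritten splitTag splits on spaces unless it is inside single quotes or parentheses, and treats commas as parameter separators only inside parentheses. The standalone sketch below restates that behaviour in a runnable form; splitTagSketch, tagItem, and the simplification of trimming parameter whitespace with TrimSpace (instead of resetting the start index on spaces, as the code above does) are this note's own additions, not part of the change.

package main

import (
	"fmt"
	"strings"
)

// tagItem mirrors the shape produced by the new splitTag: a name plus any
// parenthesised, comma-separated parameters.
type tagItem struct {
	name   string
	params []string
}

// splitTagSketch splits on spaces unless inside single quotes or parentheses;
// commas separate parameters only inside parentheses.
func splitTagSketch(s string) ([]tagItem, error) {
	s = strings.TrimSpace(s)
	var (
		inQuote, inParen    bool
		lastIdx, paramStart int
		cur                 tagItem
		out                 []tagItem
	)
	for i, r := range s {
		switch r {
		case '\'':
			inQuote = !inQuote
		case ' ':
			if !inQuote && !inParen {
				if lastIdx < i {
					if cur.name == "" {
						cur.name = s[lastIdx:i]
					}
					out = append(out, cur)
					cur = tagItem{}
				}
				lastIdx = i + 1
			}
		case ',':
			if !inQuote && !inParen {
				return nil, fmt.Errorf("unexpected comma at %d in %q", i, s)
			}
			if inParen && !inQuote {
				cur.params = append(cur.params, strings.TrimSpace(s[paramStart:i]))
				paramStart = i + 1
			}
		case '(':
			if !inQuote {
				inParen = true
				cur.name = s[lastIdx:i]
				paramStart = i + 1
			}
		case ')':
			if !inQuote {
				inParen = false
				cur.params = append(cur.params, strings.TrimSpace(s[paramStart:i]))
			}
		}
	}
	if lastIdx < len(s) {
		if cur.name == "" {
			cur.name = s[lastIdx:]
		}
		out = append(out, cur)
	}
	return out, nil
}

func main() {
	// "numeric(10, 2) notnull" => [{numeric [10 2]} {notnull []}]
	// "default('2000-01-01 00:00:00')" => [{default ['2000-01-01 00:00:00']}]
	for _, in := range []string{"numeric(10, 2) notnull", "default('2000-01-01 00:00:00')"} {
		items, err := splitTagSketch(in)
		fmt.Println(in, "=>", items, err)
	}
}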
// Context represents a context for xorm tag parse.
type Context struct {
	tagName         string
	params          []string
	tag
	tagUname        string
	preTag, nextTag string
	table           *schemas.Table
	col             *schemas.Column

@ -76,6 +120,7 @@ var (
		"CACHE":   CacheTagHandler,
		"NOCACHE": NoCacheTagHandler,
		"COMMENT": CommentTagHandler,
		"EXTENDS": ExtendsTagHandler,
	}
)

@ -124,6 +169,7 @@ func NotNullTagHandler(ctx *Context) error {
// AutoIncrTagHandler describes autoincr tag handler
func AutoIncrTagHandler(ctx *Context) error {
	ctx.col.IsAutoIncrement = true
	ctx.col.Nullable = false
	/*
		if len(ctx.params) > 0 {
			autoStartInt, err := strconv.Atoi(ctx.params[0])

@ -225,41 +271,44 @@ func CommentTagHandler(ctx *Context) error {

// SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error {
	ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
	if strings.EqualFold(ctx.tagName, "JSON") {
	ctx.col.SQLType = schemas.SQLType{Name: ctx.tagUname}
	if ctx.tagUname == "JSON" {
		ctx.col.IsJSON = true
	}
	if len(ctx.params) > 0 {
		if ctx.tagName == schemas.Enum {
			ctx.col.EnumOptions = make(map[string]int)
			for k, v := range ctx.params {
				v = strings.TrimSpace(v)
				v = strings.Trim(v, "'")
				ctx.col.EnumOptions[v] = k
	if len(ctx.params) == 0 {
		return nil
	}

	switch ctx.tagUname {
	case schemas.Enum:
		ctx.col.EnumOptions = make(map[string]int)
		for k, v := range ctx.params {
			v = strings.TrimSpace(v)
			v = strings.Trim(v, "'")
			ctx.col.EnumOptions[v] = k
		}
	case schemas.Set:
		ctx.col.SetOptions = make(map[string]int)
		for k, v := range ctx.params {
			v = strings.TrimSpace(v)
			v = strings.Trim(v, "'")
			ctx.col.SetOptions[v] = k
		}
	default:
		var err error
		if len(ctx.params) == 2 {
			ctx.col.Length, err = strconv.Atoi(ctx.params[0])
			if err != nil {
				return err
			}
		} else if ctx.tagName == schemas.Set {
			ctx.col.SetOptions = make(map[string]int)
			for k, v := range ctx.params {
				v = strings.TrimSpace(v)
				v = strings.Trim(v, "'")
				ctx.col.SetOptions[v] = k
			ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
			if err != nil {
				return err
			}
		} else {
			var err error
			if len(ctx.params) == 2 {
				ctx.col.Length, err = strconv.Atoi(ctx.params[0])
				if err != nil {
					return err
				}
				ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
				if err != nil {
					return err
				}
			} else if len(ctx.params) == 1 {
				ctx.col.Length, err = strconv.Atoi(ctx.params[0])
				if err != nil {
					return err
				}
		} else if len(ctx.params) == 1 {
			ctx.col.Length, err = strconv.Atoi(ctx.params[0])
			if err != nil {
				return err
			}
		}
	}
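Once splitTag has produced a tag name plus parameters, the default branch of the restructured SQLTypeTagHandler above converts those parameters into column lengths with strconv.Atoi. The sketch below shows just that conversion against a stand-in column type; the column struct and applyLengths are hypothetical, while xorm's real handler writes to schemas.Column through the tag Context.

package main

import (
	"fmt"
	"strconv"
)

// column is a stand-in for the few schemas.Column fields the handler touches.
type column struct {
	SQLTypeName     string
	Length, Length2 int
}

// applyLengths mirrors the default branch above: one parameter sets Length,
// two parameters set Length and Length2.
func applyLengths(col *column, params []string) error {
	var err error
	switch len(params) {
	case 2:
		if col.Length, err = strconv.Atoi(params[0]); err != nil {
			return err
		}
		if col.Length2, err = strconv.Atoi(params[1]); err != nil {
			return err
		}
	case 1:
		if col.Length, err = strconv.Atoi(params[0]); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	c := column{SQLTypeName: "NUMERIC"}
	// A tag such as numeric(10, 2) reaches the handler as params ["10", "2"].
	if err := applyLengths(&c, []string{"10", "2"}); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", c) // {SQLTypeName:NUMERIC Length:10 Length2:2}
}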
@ -289,11 +338,12 @@ func ExtendsTagHandler(ctx *Context) error {
	}
	for _, col := range parentTable.Columns() {
		col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
		col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...)

		var tagPrefix = ctx.col.FieldName
		if len(ctx.params) > 0 {
			col.Nullable = isPtr
			tagPrefix = ctx.params[0]
			tagPrefix = strings.Trim(ctx.params[0], "'")
			if col.IsPrimaryKey {
				col.Name = ctx.col.FieldName
				col.IsPrimaryKey = false

@ -315,7 +365,7 @@ func ExtendsTagHandler(ctx *Context) error {
	default:
		//TODO: warning
	}
	return nil
	return ErrIgnoreField
}

// CacheTagHandler describes cache tag handler
@ -7,24 +7,83 @@ package tags
import (
	"testing"

	"xorm.io/xorm/internal/utils"
	"github.com/stretchr/testify/assert"
)

func TestSplitTag(t *testing.T) {
	var cases = []struct {
		tag  string
		tags []string
		tags []tag
	}{
		{"not null default '2000-01-01 00:00:00' TIMESTAMP", []string{"not", "null", "default", "'2000-01-01 00:00:00'", "TIMESTAMP"}},
		{"TEXT", []string{"TEXT"}},
		{"default('2000-01-01 00:00:00')", []string{"default('2000-01-01 00:00:00')"}},
		{"json binary", []string{"json", "binary"}},
		{"not null default '2000-01-01 00:00:00' TIMESTAMP", []tag{
			{
				name: "not",
			},
			{
				name: "null",
			},
			{
				name: "default",
			},
			{
				name: "'2000-01-01 00:00:00'",
			},
			{
				name: "TIMESTAMP",
			},
		},
		},
		{"TEXT", []tag{
			{
				name: "TEXT",
			},
		},
		},
		{"default('2000-01-01 00:00:00')", []tag{
			{
				name: "default",
				params: []string{
					"'2000-01-01 00:00:00'",
				},
			},
		},
		},
		{"json binary", []tag{
			{
				name: "json",
			},
			{
				name: "binary",
			},
		},
		},
		{"numeric(10, 2)", []tag{
			{
				name:   "numeric",
				params: []string{"10", "2"},
			},
		},
		},
		{"numeric(10, 2) notnull", []tag{
			{
				name:   "numeric",
				params: []string{"10", "2"},
			},
			{
				name: "notnull",
			},
		},
		},
	}

	for _, kase := range cases {
		tags := splitTag(kase.tag)
		if !utils.SliceEq(tags, kase.tags) {
			t.Fatalf("[%d]%v is not equal [%d]%v", len(tags), tags, len(kase.tags), kase.tags)
		}
		t.Run(kase.tag, func(t *testing.T) {
			tags, err := splitTag(kase.tag)
			assert.NoError(t, err)
			assert.EqualValues(t, len(tags), len(kase.tags))
			for i := 0; i < len(tags); i++ {
				assert.Equal(t, tags[i], kase.tags[i])
			}
		})
	}
}