From 7bbbcba21bf01878a8e1ccb9041d1befb88e54ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=95=86=E8=AE=AF=E5=9C=A8=E7=BA=BF?= Date: Mon, 5 May 2014 22:26:17 +0800 Subject: [PATCH] support enum type for mysql --- engine.go | 36 +++++++++++++++++++++++------------- mysql_dialect.go | 39 ++++++++++++++++++++++++++++++--------- session.go | 16 ++++++++-------- 3 files changed, 61 insertions(+), 30 deletions(-) diff --git a/engine.go b/engine.go index 0143e5b9..4581b3e1 100644 --- a/engine.go +++ b/engine.go @@ -666,20 +666,30 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { continue } col.SQLType = core.SQLType{fs[0], 0, 0} - fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") - if len(fs2) == 2 { - col.Length, err = strconv.Atoi(fs2[0]) - if err != nil { - engine.LogError(err) + if fs[0] == core.Enum && fs[1][0] == '\'' { //enum + options := strings.Split(fs[1][0:len(fs[1])-1], ",") + col.EnumOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.EnumOptions[v] = k } - col.Length2, err = strconv.Atoi(fs2[1]) - if err != nil { - engine.LogError(err) - } - } else if len(fs2) == 1 { - col.Length, err = strconv.Atoi(fs2[0]) - if err != nil { - engine.LogError(err) + } else { + fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") + if len(fs2) == 2 { + col.Length, err = strconv.Atoi(fs2[0]) + if err != nil { + engine.LogError(err) + } + col.Length2, err = strconv.Atoi(fs2[1]) + if err != nil { + engine.LogError(err) + } + } else if len(fs2) == 1 { + col.Length, err = strconv.Atoi(fs2[0]) + if err != nil { + engine.LogError(err) + } } } } else { diff --git a/mysql_dialect.go b/mysql_dialect.go index 71273183..66107c47 100644 --- a/mysql_dialect.go +++ b/mysql_dialect.go @@ -53,6 +53,17 @@ func (db *mysql) SqlType(c *core.Column) string { case core.TimeStampz: res = core.Char c.Length = 64 + case core.Enum: //mysql enum + res = core.Enum + res += "(" + for v, k := range c.EnumOptions { + if k > 0 { + res += fmt.Sprintf(",'%v'", v) + } else { + res += fmt.Sprintf("'%v'", v) + } + } + res += ")" default: res = t } @@ -143,23 +154,33 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column } cts := strings.Split(colType, "(") + colName := cts[0] + colType = strings.ToUpper(colName) var len1, len2 int if len(cts) == 2 { idx := strings.Index(cts[1], ")") - lens := strings.Split(cts[1][0:idx], ",") - len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) - if err != nil { - return nil, nil, err - } - if len(lens) == 2 { - len2, err = strconv.Atoi(lens[1]) + if colType == core.Enum && cts[1][0] == '\'' { //enum + options := strings.Split(cts[1][0:idx], ",") + col.EnumOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.EnumOptions[v] = k + } + } else { + lens := strings.Split(cts[1][0:idx], ",") + len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) if err != nil { return nil, nil, err } + if len(lens) == 2 { + len2, err = strconv.Atoi(lens[1]) + if err != nil { + return nil, nil, err + } + } } } - colName := cts[0] - colType = strings.ToUpper(colName) col.Length = len1 col.Length2 = len2 if _, ok := core.SqlTypes[colType]; ok { diff --git a/session.go b/session.go index f312e41b..6c188085 100644 --- a/session.go +++ b/session.go @@ -718,7 +718,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } - // 查询数目太大,采用缓存将不是一个很好的方式。 + // 查询数目太大,采用缓存将不是一个很好的方式〠if len(resultsSlice) > 500 { session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 500, no cache", len(resultsSlice)) return ErrCacheFailed @@ -2484,15 +2484,15 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return fieldValue.String(), nil case reflect.Struct: if fieldType == core.TimeType { - t := fieldValue.Interface().(time.Time) - if session.Engine.dialect.DBType() == core.MSSQL { - if t.IsZero() { - return nil, nil - } - } switch fieldValue.Interface().(type) { case time.Time: - tf := session.Engine.FormatTime(col.SQLType.Name, fieldValue.Interface().(time.Time)) + t := fieldValue.Interface().(time.Time) + if session.Engine.dialect.DBType() == core.MSSQL { + if t.IsZero() { + return nil, nil + } + } + tf := session.Engine.FormatTime(col.SQLType.Name, t) return tf, nil default: return fieldValue.Interface(), nil