support enum type for mysql

This commit is contained in:
商讯在线 2014-05-05 22:26:17 +08:00
parent a190f71d40
commit 7bbbcba21b
3 changed files with 61 additions and 30 deletions

View File

@ -666,20 +666,30 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table {
continue continue
} }
col.SQLType = core.SQLType{fs[0], 0, 0} col.SQLType = core.SQLType{fs[0], 0, 0}
fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") if fs[0] == core.Enum && fs[1][0] == '\'' { //enum
if len(fs2) == 2 { options := strings.Split(fs[1][0:len(fs[1])-1], ",")
col.Length, err = strconv.Atoi(fs2[0]) col.EnumOptions = make(map[string]int)
if err != nil { for k, v := range options {
engine.LogError(err) v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
col.EnumOptions[v] = k
} }
col.Length2, err = strconv.Atoi(fs2[1]) } else {
if err != nil { fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",")
engine.LogError(err) if len(fs2) == 2 {
} col.Length, err = strconv.Atoi(fs2[0])
} else if len(fs2) == 1 { if err != nil {
col.Length, err = strconv.Atoi(fs2[0]) engine.LogError(err)
if err != nil { }
engine.LogError(err) col.Length2, err = strconv.Atoi(fs2[1])
if err != nil {
engine.LogError(err)
}
} else if len(fs2) == 1 {
col.Length, err = strconv.Atoi(fs2[0])
if err != nil {
engine.LogError(err)
}
} }
} }
} else { } else {

View File

@ -53,6 +53,17 @@ func (db *mysql) SqlType(c *core.Column) string {
case core.TimeStampz: case core.TimeStampz:
res = core.Char res = core.Char
c.Length = 64 c.Length = 64
case core.Enum: //mysql enum
res = core.Enum
res += "("
for v, k := range c.EnumOptions {
if k > 0 {
res += fmt.Sprintf(",'%v'", v)
} else {
res += fmt.Sprintf("'%v'", v)
}
}
res += ")"
default: default:
res = t res = t
} }
@ -143,23 +154,33 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column
} }
cts := strings.Split(colType, "(") cts := strings.Split(colType, "(")
colName := cts[0]
colType = strings.ToUpper(colName)
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",") if colType == core.Enum && cts[1][0] == '\'' { //enum
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) options := strings.Split(cts[1][0:idx], ",")
if err != nil { col.EnumOptions = make(map[string]int)
return nil, nil, err for k, v := range options {
} v = strings.TrimSpace(v)
if len(lens) == 2 { v = strings.Trim(v, "'")
len2, err = strconv.Atoi(lens[1]) col.EnumOptions[v] = k
}
} else {
lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1])
if err != nil {
return nil, nil, err
}
}
} }
} }
colName := cts[0]
colType = strings.ToUpper(colName)
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := core.SqlTypes[colType]; ok { if _, ok := core.SqlTypes[colType]; ok {

View File

@ -718,7 +718,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
if err != nil { if err != nil {
return err return err
} }
// 查询数目太大,采用缓存将不是一个很好的方式。 // 查询数目太大,采用缓存将不是一个很好的方式ã€
if len(resultsSlice) > 500 { if len(resultsSlice) > 500 {
session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 500, no cache", len(resultsSlice)) session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 500, no cache", len(resultsSlice))
return ErrCacheFailed return ErrCacheFailed
@ -2484,15 +2484,15 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val
return fieldValue.String(), nil return fieldValue.String(), nil
case reflect.Struct: case reflect.Struct:
if fieldType == core.TimeType { if fieldType == core.TimeType {
t := fieldValue.Interface().(time.Time)
if session.Engine.dialect.DBType() == core.MSSQL {
if t.IsZero() {
return nil, nil
}
}
switch fieldValue.Interface().(type) { switch fieldValue.Interface().(type) {
case time.Time: case time.Time:
tf := session.Engine.FormatTime(col.SQLType.Name, fieldValue.Interface().(time.Time)) t := fieldValue.Interface().(time.Time)
if session.Engine.dialect.DBType() == core.MSSQL {
if t.IsZero() {
return nil, nil
}
}
tf := session.Engine.FormatTime(col.SQLType.Name, t)
return tf, nil return tf, nil
default: default:
return fieldValue.Interface(), nil return fieldValue.Interface(), nil