diff --git a/engine.go b/engine.go index 0962a08e..6ab7d465 100644 --- a/engine.go +++ b/engine.go @@ -273,7 +273,6 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) { } //table.Columns = cols //table.ColumnsSeq = colSeq - indexes, err := engine.dialect.GetIndexes(table.Name) if err != nil { return nil, err @@ -732,6 +731,14 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { v = strings.Trim(v, "'") col.EnumOptions[v] = k } + } else if fs[0] == core.Set && fs[1][0] == '\'' { //set + options := strings.Split(fs[1][0:len(fs[1])-1], ",") + col.SetOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.SetOptions[v] = k + } } else { fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") if len(fs2) == 2 { diff --git a/mysql_dialect.go b/mysql_dialect.go index 4e430165..b9911138 100644 --- a/mysql_dialect.go +++ b/mysql_dialect.go @@ -64,6 +64,17 @@ func (db *mysql) SqlType(c *core.Column) string { } } res += ")" + case core.Set: //mysql set + res = core.Set + res += "(" + for v, k := range c.SetOptions { + if k > 0 { + res += fmt.Sprintf(",'%v'", v) + } else { + res += fmt.Sprintf("'%v'", v) + } + } + res += ")" default: res = t } @@ -145,6 +156,7 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column if err != nil { return nil, nil, err } + //fmt.Println(columnName, isNullable, colType, colKey, extra, colDefault) col.Name = strings.Trim(columnName, "` ") if "YES" == isNullable { col.Nullable = true @@ -171,6 +183,14 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column v = strings.Trim(v, "'") col.EnumOptions[v] = k } + } else if colType == core.Set && cts[1][0] == '\'' { + options := strings.Split(cts[1][0:idx], ",") + col.SetOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.SetOptions[v] = k + } } else { lens := strings.Split(cts[1][0:idx], ",") len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) @@ -185,6 +205,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column } } } + if colType == "FLOAT UNSIGNED" { + colType = "FLOAT" + } col.Length = len1 col.Length2 = len2 if _, ok := core.SqlTypes[colType]; ok {