Merge pull request #55 from admpub/patch-1

fixbug: xorm reverse; miss empty string for default value
This commit is contained in:
lunny 2014-02-08 11:55:12 +08:00
commit e91ca94921
9 changed files with 290 additions and 223 deletions

View File

@ -484,7 +484,8 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
var indexType int var indexType int
var indexName string var indexName string
var preKey string var preKey string
for j, key := range tags { for j,ln := 0,len(tags); j < ln; j++ {
key := tags[j]
k := strings.ToUpper(key) k := strings.ToUpper(key)
switch { switch {
case k == "<-": case k == "<-":
@ -535,7 +536,18 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
if preKey != "DEFAULT" { if preKey != "DEFAULT" {
col.Name = key[1 : len(key)-1] col.Name = key[1 : len(key)-1]
} }
} else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { } else if strings.Contains(k, "(") && (strings.HasSuffix(k, ")") || strings.HasSuffix(k, ",")) {
//[SWH|+]
if strings.HasSuffix(k, ",") {
j++
for j < ln {
k += tags[j]
if strings.HasSuffix(tags[j], ")") {
break
}
j++
}
}
fs := strings.Split(k, "(") fs := strings.Split(k, "(")
if _, ok := sqlTypes[fs[0]]; !ok { if _, ok := sqlTypes[fs[0]]; !ok {
preKey = k preKey = k
@ -611,9 +623,24 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
} }
} else { } else {
sqlType := Type2SQLType(fieldType) sqlType := Type2SQLType(fieldType)
col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, col = &Column{
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, Name: engine.columnMapper.Obj2Table(t.Field(i).Name),
TWOSIDES, false, false, false, false} FieldName: t.Field(i).Name,
SQLType: sqlType,
Length: sqlType.DefaultLength,
Length2: sqlType.DefaultLength2,
Nullable: true,
Default: "",
Indexes: make(map[string]bool),
IsPrimaryKey: false,
IsAutoIncrement:false,
MapType: TWOSIDES,
IsCreated: false,
IsUpdated: false,
IsCascade: false,
IsVersion: false,
DefaultIsEmpty: false,
}
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
col.Nullable = false col.Nullable = false

View File

@ -187,6 +187,10 @@ where a.object_id=object_id('` + tableName + `')`
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
}else{
if col.DefaultIsEmpty {
col.Default = "''"
}
} }
} }
cols[col.Name] = col cols[col.Name] = col

View File

@ -212,6 +212,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
case "COLUMN_DEFAULT": case "COLUMN_DEFAULT":
// add '' // add ''
col.Default = string(content) col.Default = string(content)
if col.Default == "" {
col.DefaultIsEmpty = true
}
case "COLUMN_TYPE": case "COLUMN_TYPE":
cts := strings.Split(string(content), "(") cts := strings.Split(string(content), "(")
var len1, len2 int var len1, len2 int
@ -256,6 +259,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
}else{
if col.DefaultIsEmpty {
col.Default = "''"
}
} }
} }
cols[col.Name] = col cols[col.Name] = col

View File

@ -139,6 +139,9 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
col.Name = strings.Trim(string(content), `" `) col.Name = strings.Trim(string(content), `" `)
case "data_default": case "data_default":
col.Default = string(content) col.Default = string(content)
if col.Default == "" {
col.DefaultIsEmpty = true
}
case "nullable": case "nullable":
if string(content) == "Y" { if string(content) == "Y" {
col.Nullable = true col.Nullable = true
@ -171,6 +174,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
}else{
if col.DefaultIsEmpty {
col.Default = "''"
}
} }
} }
cols[col.Name] = col cols[col.Name] = col

View File

@ -177,6 +177,9 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
col.IsPrimaryKey = true col.IsPrimaryKey = true
} else { } else {
col.Default = string(content) col.Default = string(content)
if col.Default == "" {
col.DefaultIsEmpty = true
}
} }
case "is_nullable": case "is_nullable":
if string(content) == "YES" { if string(content) == "YES" {
@ -218,6 +221,10 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
}else{
if col.DefaultIsEmpty {
col.Default = "''"
}
} }
} }
cols[col.Name] = col cols[col.Name] = col

View File

@ -275,6 +275,7 @@ type Column struct {
IsUpdated bool IsUpdated bool
IsCascade bool IsCascade bool
IsVersion bool IsVersion bool
DefaultIsEmpty bool
} }
// generate column description string according dialect // generate column description string according dialect

View File

@ -241,7 +241,7 @@ func tag(table *xorm.Table, col *xorm.Column) string {
nstr := col.SQLType.Name nstr := col.SQLType.Name
if col.Length != 0 { if col.Length != 0 {
if col.Length2 != 0 { if col.Length2 != 0 {
nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2) nstr += fmt.Sprintf("(%v,%v)", col.Length, col.Length2)
} else { } else {
nstr += fmt.Sprintf("(%v)", col.Length) nstr += fmt.Sprintf("(%v)", col.Length)
} }

View File

@ -1,26 +1,27 @@
package main package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
"github.com/dvirsky/go-pylog/logging" "github.com/dvirsky/go-pylog/logging"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"text/template" "strings" //[SWH|+]
"text/template"
) )
var CmdReverse = &Command{ var CmdReverse = &Command{
UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]", UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]",
Short: "reverse a db to codes", Short: "reverse a db to codes",
Long: ` Long: `
according database's tables and columns to generate codes for Go, C++ and etc. according database's tables and columns to generate codes for Go, C++ and etc.
-m Generated one go file for every table -m Generated one go file for every table
@ -33,236 +34,248 @@ according database's tables and columns to generate codes for Go, C++ and etc.
} }
func init() { func init() {
CmdReverse.Run = runReverse CmdReverse.Run = runReverse
CmdReverse.Flags = map[string]bool{ CmdReverse.Flags = map[string]bool{
"-s": false, "-s": false,
"-l": false, "-l": false,
} }
} }
var ( var (
genJson bool = false genJson bool = false
) )
func printReversePrompt(flag string) { func printReversePrompt(flag string) {
} }
type Tmpl struct { type Tmpl struct {
Tables []*xorm.Table Tables []*xorm.Table
Imports map[string]string Imports map[string]string
Model string Model string
} }
func dirExists(dir string) bool { func dirExists(dir string) bool {
d, e := os.Stat(dir) d, e := os.Stat(dir)
switch { switch {
case e != nil: case e != nil:
return false return false
case !d.IsDir(): case !d.IsDir():
return false return false
} }
return true return true
} }
func runReverse(cmd *Command, args []string) { func runReverse(cmd *Command, args []string) {
num := checkFlags(cmd.Flags, args, printReversePrompt) num := checkFlags(cmd.Flags, args, printReversePrompt)
if num == -1 { if num == -1 {
return return
} }
args = args[num:] args = args[num:]
if len(args) < 3 { if len(args) < 3 {
fmt.Println("params error, please see xorm help reverse") fmt.Println("params error, please see xorm help reverse")
return return
} }
var isMultiFile bool = true var isMultiFile bool = true
if use, ok := cmd.Flags["-s"]; ok { if use, ok := cmd.Flags["-s"]; ok {
isMultiFile = !use isMultiFile = !use
} }
curPath, err := os.Getwd() curPath, err := os.Getwd()
if err != nil { if err != nil {
fmt.Println(curPath) fmt.Println(curPath)
return return
} }
var genDir string var genDir string
var model string var model string
if len(args) == 4 { if len(args) == 4 {
genDir, err = filepath.Abs(args[3]) genDir, err = filepath.Abs(args[3])
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
model = path.Base(genDir) //[SWH|+] 经测试path.Base不能解析windows下的“\”,需要替换为“/”
} else { genDir = strings.Replace(genDir, "\\", "/", -1)
model = "model" model = path.Base(genDir)
genDir = path.Join(curPath, model) } else {
} model = "model"
genDir = path.Join(curPath, model)
}
dir, err := filepath.Abs(args[2]) dir, err := filepath.Abs(args[2])
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
if !dirExists(dir) { if !dirExists(dir) {
logging.Error("Template %v path is not exist", dir) logging.Error("Template %v path is not exist", dir)
return return
} }
var langTmpl LangTmpl var langTmpl LangTmpl
var ok bool var ok bool
var lang string = "go" var lang string = "go"
var prefix string = "" //[SWH|+]
cfgPath := path.Join(dir, "config")
info, err := os.Stat(cfgPath)
var configs map[string]string
if err == nil && !info.IsDir() {
configs = loadConfig(cfgPath)
if l, ok := configs["lang"]; ok {
lang = l
}
if j, ok := configs["genJson"]; ok {
genJson, err = strconv.ParseBool(j)
}
//[SWH|+]
if j, ok := configs["prefix"]; ok {
prefix = j
}
}
cfgPath := path.Join(dir, "config") if langTmpl, ok = langTmpls[lang]; !ok {
info, err := os.Stat(cfgPath) fmt.Println("Unsupported programing language", lang)
var configs map[string]string return
if err == nil && !info.IsDir() { }
configs = loadConfig(cfgPath)
if l, ok := configs["lang"]; ok {
lang = l
}
if j, ok := configs["genJson"]; ok {
genJson, err = strconv.ParseBool(j)
}
}
if langTmpl, ok = langTmpls[lang]; !ok { os.MkdirAll(genDir, os.ModePerm)
fmt.Println("Unsupported programing language", lang)
return
}
os.MkdirAll(genDir, os.ModePerm) Orm, err := xorm.NewEngine(args[0], args[1])
if err != nil {
logging.Error("%v", err)
return
}
Orm, err := xorm.NewEngine(args[0], args[1]) tables, err := Orm.DBMetas()
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
tables, err := Orm.DBMetas() filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
if err != nil { if info.IsDir() {
logging.Error("%v", err) return nil
return }
}
filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { if info.Name() == "config" {
if info.IsDir() { return nil
return nil }
}
if info.Name() == "config" { bs, err := ioutil.ReadFile(f)
return nil if err != nil {
} logging.Error("%v", err)
return err
}
bs, err := ioutil.ReadFile(f) t := template.New(f)
if err != nil { t.Funcs(langTmpl.Funcs)
logging.Error("%v", err)
return err
}
t := template.New(f) tmpl, err := t.Parse(string(bs))
t.Funcs(langTmpl.Funcs) if err != nil {
logging.Error("%v", err)
return err
}
tmpl, err := t.Parse(string(bs)) var w *os.File
if err != nil { fileName := info.Name()
logging.Error("%v", err) newFileName := fileName[:len(fileName)-4]
return err ext := path.Ext(newFileName)
}
var w *os.File if !isMultiFile {
fileName := info.Name() w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600)
newFileName := fileName[:len(fileName)-4] if err != nil {
ext := path.Ext(newFileName) logging.Error("%v", err)
return err
}
if !isMultiFile { imports := langTmpl.GenImports(tables)
w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600) tbls := make([]*xorm.Table, 0)
if err != nil { for _, table := range tables {
logging.Error("%v", err) //[SWH|+]
return err if prefix != "" {
} table.Name = strings.TrimPrefix(table.Name, prefix)
}
tbls = append(tbls, table)
}
imports := langTmpl.GenImports(tables) newbytes := bytes.NewBufferString("")
tbls := make([]*xorm.Table, 0) t := &Tmpl{Tables: tbls, Imports: imports, Model: model}
for _, table := range tables { err = tmpl.Execute(newbytes, t)
tbls = append(tbls, table) if err != nil {
} logging.Error("%v", err)
return err
}
newbytes := bytes.NewBufferString("") tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil {
logging.Error("%v", err)
return err
}
var source string
if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent))
if err != nil {
logging.Error("%v", err)
return err
}
} else {
source = string(tplcontent)
}
t := &Tmpl{Tables: tbls, Imports: imports, Model: model} w.WriteString(source)
err = tmpl.Execute(newbytes, t) w.Close()
if err != nil { } else {
logging.Error("%v", err) for _, table := range tables {
return err //[SWH|+]
} if prefix != "" {
table.Name = strings.TrimPrefix(table.Name, prefix)
}
// imports
tbs := []*xorm.Table{table}
imports := langTmpl.GenImports(tbs)
w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
logging.Error("%v", err)
return err
}
tplcontent, err := ioutil.ReadAll(newbytes) newbytes := bytes.NewBufferString("")
if err != nil {
logging.Error("%v", err)
return err
}
var source string
if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent))
if err != nil {
logging.Error("%v", err)
return err
}
} else {
source = string(tplcontent)
}
w.WriteString(source) t := &Tmpl{Tables: tbs, Imports: imports, Model: model}
w.Close() err = tmpl.Execute(newbytes, t)
} else { if err != nil {
for _, table := range tables { logging.Error("%v", err)
// imports return err
tbs := []*xorm.Table{table} }
imports := langTmpl.GenImports(tbs)
w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var source string
if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent))
if err != nil {
logging.Error("%v-%v", err, string(tplcontent))
return err
}
} else {
source = string(tplcontent)
}
newbytes := bytes.NewBufferString("") w.WriteString(source)
w.Close()
}
}
t := &Tmpl{Tables: tbs, Imports: imports, Model: model} return nil
err = tmpl.Execute(newbytes, t) })
if err != nil {
logging.Error("%v", err)
return err
}
tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil {
logging.Error("%v", err)
return err
}
var source string
if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent))
if err != nil {
logging.Error("%v-%v", err, string(tplcontent))
return err
}
} else {
source = string(tplcontent)
}
w.WriteString(source)
w.Close()
}
}
return nil
})
} }

View File

@ -1,2 +1,3 @@
lang=go lang=go
genJson=0 genJson=0
prefix=cos_