diff --git a/engine.go b/engine.go index 1f2312de..71ef63b4 100644 --- a/engine.go +++ b/engine.go @@ -611,9 +611,24 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } } else { sqlType := Type2SQLType(fieldType) - col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, - sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, - TWOSIDES, false, false, false, false} + col = &Column{ + Name: engine.columnMapper.Obj2Table(t.Field(i).Name), + FieldName: t.Field(i).Name, + SQLType: sqlType, + Length: sqlType.DefaultLength, + Length2: sqlType.DefaultLength2, + Nullable: true, + Default: "", + Indexes: make(map[string]bool), + IsPrimaryKey: false, + IsAutoIncrement:false, + MapType: TWOSIDES, + IsCreated: false, + IsUpdated: false, + IsCascade: false, + IsVersion: false, + DefaultIsEmpty: false, + } } if col.IsAutoIncrement { col.Nullable = false diff --git a/mssql.go b/mssql.go index 3606332f..6e9776d2 100644 --- a/mssql.go +++ b/mssql.go @@ -136,8 +136,8 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mssql) GetColumns(tableName string) ([]string, map[string]*Column, error) { args := []interface{}{} - s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale -from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id + s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale +from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id where a.object_id=object_id('` + tableName + `')` cnn, err := sql.Open(db.driverName, db.dataSourceName) if err != nil { @@ -187,6 +187,10 @@ where a.object_id=object_id('` + tableName + `')` if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" + }else{ + if col.DefaultIsEmpty { + col.Default = "''" + } } } cols[col.Name] = col @@ -224,18 +228,18 @@ func (db *mssql) GetTables() ([]*Table, error) { func (db *mssql) GetIndexes(tableName string) (map[string]*Index, error) { args := []interface{}{tableName} - s := `SELECT -IXS.NAME AS [INDEX_NAME], -C.NAME AS [COLUMN_NAME], -IXS.is_unique AS [IS_UNIQUE], -CASE IXCS.IS_INCLUDED_COLUMN -WHEN 0 THEN 'NONE' -ELSE 'INCLUDED' END AS [IS_INCLUDED_COLUMN] -FROM SYS.INDEXES IXS -INNER JOIN SYS.INDEX_COLUMNS IXCS -ON IXS.OBJECT_ID=IXCS.OBJECT_ID AND IXS.INDEX_ID = IXCS.INDEX_ID -INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID -AND IXCS.COLUMN_ID=C.COLUMN_ID + s := `SELECT +IXS.NAME AS [INDEX_NAME], +C.NAME AS [COLUMN_NAME], +IXS.is_unique AS [IS_UNIQUE], +CASE IXCS.IS_INCLUDED_COLUMN +WHEN 0 THEN 'NONE' +ELSE 'INCLUDED' END AS [IS_INCLUDED_COLUMN] +FROM SYS.INDEXES IXS +INNER JOIN SYS.INDEX_COLUMNS IXCS +ON IXS.OBJECT_ID=IXCS.OBJECT_ID AND IXS.INDEX_ID = IXCS.INDEX_ID +INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID +AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` cnn, err := sql.Open(db.driverName, db.dataSourceName) diff --git a/mysql.go b/mysql.go index 4dcde839..aff13333 100644 --- a/mysql.go +++ b/mysql.go @@ -212,6 +212,9 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err case "COLUMN_DEFAULT": // add '' col.Default = string(content) + if col.Default == "" { + col.DefaultIsEmpty = true + } case "COLUMN_TYPE": cts := strings.Split(string(content), "(") var len1, len2 int @@ -256,6 +259,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" + }else{ + if col.DefaultIsEmpty { + col.Default = "''" + } } } cols[col.Name] = col diff --git a/oracle.go b/oracle.go index 4e3c6fb6..0b4238ca 100644 --- a/oracle.go +++ b/oracle.go @@ -139,6 +139,9 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er col.Name = strings.Trim(string(content), `" `) case "data_default": col.Default = string(content) + if col.Default == "" { + col.DefaultIsEmpty = true + } case "nullable": if string(content) == "Y" { col.Nullable = true @@ -171,6 +174,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" + }else{ + if col.DefaultIsEmpty { + col.Default = "''" + } } } cols[col.Name] = col diff --git a/postgres.go b/postgres.go index 97550543..4c7f97e2 100644 --- a/postgres.go +++ b/postgres.go @@ -177,6 +177,9 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, col.IsPrimaryKey = true } else { col.Default = string(content) + if col.Default == "" { + col.DefaultIsEmpty = true + } } case "is_nullable": if string(content) == "YES" { @@ -218,6 +221,10 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" + }else{ + if col.DefaultIsEmpty { + col.Default = "''" + } } } cols[col.Name] = col diff --git a/table.go b/table.go index 76b4c3ae..34cb0862 100644 --- a/table.go +++ b/table.go @@ -275,6 +275,7 @@ type Column struct { IsUpdated bool IsCascade bool IsVersion bool + DefaultIsEmpty bool } // generate column description string according dialect diff --git a/xorm/go.go b/xorm/go.go index 682f6b0b..533ce026 100644 --- a/xorm/go.go +++ b/xorm/go.go @@ -241,7 +241,7 @@ func tag(table *xorm.Table, col *xorm.Column) string { nstr := col.SQLType.Name if col.Length != 0 { if col.Length2 != 0 { - nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2) + nstr += fmt.Sprintf("(%v,%v)", col.Length, col.Length2) } else { nstr += fmt.Sprintf("(%v)", col.Length) } diff --git a/xorm/reverse.go b/xorm/reverse.go index 17accfe5..7eab1980 100644 --- a/xorm/reverse.go +++ b/xorm/reverse.go @@ -1,26 +1,27 @@ package main import ( - "bytes" - "fmt" - _ "github.com/bylevel/pq" - "github.com/dvirsky/go-pylog/logging" - _ "github.com/go-sql-driver/mysql" - "github.com/lunny/xorm" - _ "github.com/mattn/go-sqlite3" - _ "github.com/ziutek/mymysql/godrv" - "io/ioutil" - "os" - "path" - "path/filepath" - "strconv" - "text/template" + "bytes" + "fmt" + _ "github.com/bylevel/pq" + "github.com/dvirsky/go-pylog/logging" + _ "github.com/go-sql-driver/mysql" + "github.com/lunny/xorm" + _ "github.com/mattn/go-sqlite3" + _ "github.com/ziutek/mymysql/godrv" + "io/ioutil" + "os" + "path" + "path/filepath" + "strconv" + "strings" //[SWH|+] + "text/template" ) var CmdReverse = &Command{ - UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]", - Short: "reverse a db to codes", - Long: ` + UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]", + Short: "reverse a db to codes", + Long: ` according database's tables and columns to generate codes for Go, C++ and etc. -m Generated one go file for every table @@ -33,236 +34,248 @@ according database's tables and columns to generate codes for Go, C++ and etc. } func init() { - CmdReverse.Run = runReverse - CmdReverse.Flags = map[string]bool{ - "-s": false, - "-l": false, - } + CmdReverse.Run = runReverse + CmdReverse.Flags = map[string]bool{ + "-s": false, + "-l": false, + } } var ( - genJson bool = false + genJson bool = false ) func printReversePrompt(flag string) { } type Tmpl struct { - Tables []*xorm.Table - Imports map[string]string - Model string + Tables []*xorm.Table + Imports map[string]string + Model string } func dirExists(dir string) bool { - d, e := os.Stat(dir) - switch { - case e != nil: - return false - case !d.IsDir(): - return false - } + d, e := os.Stat(dir) + switch { + case e != nil: + return false + case !d.IsDir(): + return false + } - return true + return true } func runReverse(cmd *Command, args []string) { - num := checkFlags(cmd.Flags, args, printReversePrompt) - if num == -1 { - return - } - args = args[num:] + num := checkFlags(cmd.Flags, args, printReversePrompt) + if num == -1 { + return + } + args = args[num:] - if len(args) < 3 { - fmt.Println("params error, please see xorm help reverse") - return - } + if len(args) < 3 { + fmt.Println("params error, please see xorm help reverse") + return + } - var isMultiFile bool = true - if use, ok := cmd.Flags["-s"]; ok { - isMultiFile = !use - } + var isMultiFile bool = true + if use, ok := cmd.Flags["-s"]; ok { + isMultiFile = !use + } - curPath, err := os.Getwd() - if err != nil { - fmt.Println(curPath) - return - } + curPath, err := os.Getwd() + if err != nil { + fmt.Println(curPath) + return + } - var genDir string - var model string - if len(args) == 4 { + var genDir string + var model string + if len(args) == 4 { - genDir, err = filepath.Abs(args[3]) - if err != nil { - fmt.Println(err) - return - } - model = path.Base(genDir) - } else { - model = "model" - genDir = path.Join(curPath, model) - } + genDir, err = filepath.Abs(args[3]) + if err != nil { + fmt.Println(err) + return + } + //[SWH|+] 经测试,path.Base不能解析windows下的“\”,需要替换为“/” + genDir = strings.Replace(genDir, "\\", "/", -1) + model = path.Base(genDir) + } else { + model = "model" + genDir = path.Join(curPath, model) + } - dir, err := filepath.Abs(args[2]) - if err != nil { - logging.Error("%v", err) - return - } + dir, err := filepath.Abs(args[2]) + if err != nil { + logging.Error("%v", err) + return + } - if !dirExists(dir) { - logging.Error("Template %v path is not exist", dir) - return - } + if !dirExists(dir) { + logging.Error("Template %v path is not exist", dir) + return + } - var langTmpl LangTmpl - var ok bool - var lang string = "go" + var langTmpl LangTmpl + var ok bool + var lang string = "go" + var prefix string = "" //[SWH|+] + cfgPath := path.Join(dir, "config") + info, err := os.Stat(cfgPath) + var configs map[string]string + if err == nil && !info.IsDir() { + configs = loadConfig(cfgPath) + if l, ok := configs["lang"]; ok { + lang = l + } + if j, ok := configs["genJson"]; ok { + genJson, err = strconv.ParseBool(j) + } + //[SWH|+] + if j, ok := configs["prefix"]; ok { + prefix = j + } + } - cfgPath := path.Join(dir, "config") - info, err := os.Stat(cfgPath) - var configs map[string]string - if err == nil && !info.IsDir() { - configs = loadConfig(cfgPath) - if l, ok := configs["lang"]; ok { - lang = l - } - if j, ok := configs["genJson"]; ok { - genJson, err = strconv.ParseBool(j) - } - } + if langTmpl, ok = langTmpls[lang]; !ok { + fmt.Println("Unsupported programing language", lang) + return + } - if langTmpl, ok = langTmpls[lang]; !ok { - fmt.Println("Unsupported programing language", lang) - return - } + os.MkdirAll(genDir, os.ModePerm) - os.MkdirAll(genDir, os.ModePerm) + Orm, err := xorm.NewEngine(args[0], args[1]) + if err != nil { + logging.Error("%v", err) + return + } - Orm, err := xorm.NewEngine(args[0], args[1]) - if err != nil { - logging.Error("%v", err) - return - } + tables, err := Orm.DBMetas() + if err != nil { + logging.Error("%v", err) + return + } - tables, err := Orm.DBMetas() - if err != nil { - logging.Error("%v", err) - return - } + filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { + if info.IsDir() { + return nil + } - filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { - if info.IsDir() { - return nil - } + if info.Name() == "config" { + return nil + } - if info.Name() == "config" { - return nil - } + bs, err := ioutil.ReadFile(f) + if err != nil { + logging.Error("%v", err) + return err + } - bs, err := ioutil.ReadFile(f) - if err != nil { - logging.Error("%v", err) - return err - } + t := template.New(f) + t.Funcs(langTmpl.Funcs) - t := template.New(f) - t.Funcs(langTmpl.Funcs) + tmpl, err := t.Parse(string(bs)) + if err != nil { + logging.Error("%v", err) + return err + } - tmpl, err := t.Parse(string(bs)) - if err != nil { - logging.Error("%v", err) - return err - } + var w *os.File + fileName := info.Name() + newFileName := fileName[:len(fileName)-4] + ext := path.Ext(newFileName) - var w *os.File - fileName := info.Name() - newFileName := fileName[:len(fileName)-4] - ext := path.Ext(newFileName) + if !isMultiFile { + w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + logging.Error("%v", err) + return err + } - if !isMultiFile { - w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600) - if err != nil { - logging.Error("%v", err) - return err - } + imports := langTmpl.GenImports(tables) + tbls := make([]*xorm.Table, 0) + for _, table := range tables { + //[SWH|+] + if prefix != "" { + table.Name = strings.TrimPrefix(table.Name, prefix) + } + tbls = append(tbls, table) + } - imports := langTmpl.GenImports(tables) + newbytes := bytes.NewBufferString("") - tbls := make([]*xorm.Table, 0) - for _, table := range tables { - tbls = append(tbls, table) - } + t := &Tmpl{Tables: tbls, Imports: imports, Model: model} + err = tmpl.Execute(newbytes, t) + if err != nil { + logging.Error("%v", err) + return err + } - newbytes := bytes.NewBufferString("") + tplcontent, err := ioutil.ReadAll(newbytes) + if err != nil { + logging.Error("%v", err) + return err + } + var source string + if langTmpl.Formater != nil { + source, err = langTmpl.Formater(string(tplcontent)) + if err != nil { + logging.Error("%v", err) + return err + } + } else { + source = string(tplcontent) + } - t := &Tmpl{Tables: tbls, Imports: imports, Model: model} - err = tmpl.Execute(newbytes, t) - if err != nil { - logging.Error("%v", err) - return err - } + w.WriteString(source) + w.Close() + } else { + for _, table := range tables { + //[SWH|+] + if prefix != "" { + table.Name = strings.TrimPrefix(table.Name, prefix) + } + // imports + tbs := []*xorm.Table{table} + imports := langTmpl.GenImports(tbs) + w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + logging.Error("%v", err) + return err + } - tplcontent, err := ioutil.ReadAll(newbytes) - if err != nil { - logging.Error("%v", err) - return err - } - var source string - if langTmpl.Formater != nil { - source, err = langTmpl.Formater(string(tplcontent)) - if err != nil { - logging.Error("%v", err) - return err - } - } else { - source = string(tplcontent) - } + newbytes := bytes.NewBufferString("") - w.WriteString(source) - w.Close() - } else { - for _, table := range tables { - // imports - tbs := []*xorm.Table{table} - imports := langTmpl.GenImports(tbs) + t := &Tmpl{Tables: tbs, Imports: imports, Model: model} + err = tmpl.Execute(newbytes, t) + if err != nil { + logging.Error("%v", err) + return err + } - w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) - if err != nil { - logging.Error("%v", err) - return err - } + tplcontent, err := ioutil.ReadAll(newbytes) + if err != nil { + logging.Error("%v", err) + return err + } + var source string + if langTmpl.Formater != nil { + source, err = langTmpl.Formater(string(tplcontent)) + if err != nil { + logging.Error("%v-%v", err, string(tplcontent)) + return err + } + } else { + source = string(tplcontent) + } - newbytes := bytes.NewBufferString("") + w.WriteString(source) + w.Close() + } + } - t := &Tmpl{Tables: tbs, Imports: imports, Model: model} - err = tmpl.Execute(newbytes, t) - if err != nil { - logging.Error("%v", err) - return err - } - - tplcontent, err := ioutil.ReadAll(newbytes) - if err != nil { - logging.Error("%v", err) - return err - } - var source string - if langTmpl.Formater != nil { - source, err = langTmpl.Formater(string(tplcontent)) - if err != nil { - logging.Error("%v-%v", err, string(tplcontent)) - return err - } - } else { - source = string(tplcontent) - } - - w.WriteString(source) - w.Close() - } - } - - return nil - }) + return nil + }) } diff --git a/xorm/templates/goxorm/config b/xorm/templates/goxorm/config index e99ad029..5d7bf321 100644 --- a/xorm/templates/goxorm/config +++ b/xorm/templates/goxorm/config @@ -1,2 +1,3 @@ lang=go genJson=0 +prefix=cos_