added xorm reverse tool

This commit is contained in:
Lunny Xiao 2013-10-13 23:57:57 +08:00
parent 2caed88b82
commit 42b4dbba03
9 changed files with 405 additions and 65 deletions

View File

@ -200,6 +200,16 @@ func (engine *Engine) DBMetas() ([]*Table, error) {
return nil, err return nil, err
} }
table.Indexes = indexes table.Indexes = indexes
for _, index := range indexes {
for _, name := range index.Cols {
if col, ok := table.Columns[name]; ok {
col.Indexes[index.Name] = true
} else {
return nil, errors.New("Unkonwn col " + name + " in indexes")
}
}
}
} }
return tables, nil return tables, nil
} }

View File

@ -155,6 +155,7 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -162,10 +163,11 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) {
cols := make(map[string]*Column) cols := make(map[string]*Column)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "COLUMN_NAME": case "COLUMN_NAME":
col.Name = string(content) col.Name = strings.Trim(string(content), "` ")
case "IS_NULLABLE": case "IS_NULLABLE":
if "YES" == string(content) { if "YES" == string(content) {
col.Nullable = true col.Nullable = true
@ -225,6 +227,7 @@ func (db *mysql) GetTables() ([]*Table, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -236,7 +239,7 @@ func (db *mysql) GetTables() ([]*Table, error) {
for name, content := range record { for name, content := range record {
switch name { switch name {
case "TABLE_NAME": case "TABLE_NAME":
table.Name = string(content) table.Name = strings.Trim(string(content), "` ")
case "ENGINE": case "ENGINE":
} }
} }
@ -252,6 +255,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -272,7 +276,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
case "INDEX_NAME": case "INDEX_NAME":
indexName = string(content) indexName = string(content)
case "COLUMN_NAME": case "COLUMN_NAME":
colName = string(content) colName = strings.Trim(string(content), "` ")
} }
} }
if indexName == "PRIMARY" { if indexName == "PRIMARY" {

View File

@ -1,6 +1,7 @@
package xorm package xorm
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -141,13 +142,14 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa
func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT COLUMN_NAME, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
" FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -155,26 +157,128 @@ func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) {
cols := make(map[string]*Column) cols := make(map[string]*Column)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "COLUMN_NAME": case "column_name":
col.Name = string(content) col.Name = strings.Trim(string(content), `" `)
case "column_default": case "column_default":
if strings.HasPrefix(string(content), "") { if strings.HasPrefix(string(content), "nextval") {
col.IsPrimaryKey col.IsPrimaryKey = true
} }
case "is_nullable":
if string(content) == "YES" {
col.Nullable = true
} else {
col.Nullable = false
}
case "data_type":
ct := string(content)
switch ct {
case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0}
case "double precision":
col.SQLType = SQLType{Double, 0, 0}
case "boolean":
col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone":
col.SQLType = SQLType{Time, 0, 0}
default:
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
}
if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
}
case "character_maximum_length":
i, err := strconv.Atoi(string(content))
if err != nil {
return nil, errors.New("retrieve length error")
}
col.Length = i
case "numeric_precision":
case "numeric_precision_radix":
} }
} }
cols[col.Name] = col
} }
return nil, ErrNotImplemented return cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
return nil, ErrNotImplemented args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
tables := make([]*Table, 0)
for _, record := range res {
table := new(Table)
for name, content := range record {
switch name {
case "tablename":
table.Name = string(content)
}
}
tables = append(tables, table)
}
return tables, nil
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, ErrNotImplemented args := []interface{}{tableName}
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
indexes := make(map[string]*Index, 0)
for _, record := range res {
var indexType int
var indexName string
var colNames []string
for name, content := range record {
switch name {
case "indexname":
indexName = strings.Trim(string(content), `" `)
case "indexdef":
c := string(content)
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
indexType = UniqueType
} else {
indexType = IndexType
}
cs := strings.Split(c, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
}
}
if strings.HasSuffix(indexName, "_pkey") {
continue
}
indexName = indexName[5+len(tableName) : len(indexName)]
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
}
indexes[index.Name] = index
}
return indexes, nil
} }

View File

@ -1,5 +1,11 @@
package xorm package xorm
import (
"database/sql"
"fmt"
"strings"
)
type sqlite3 struct { type sqlite3 struct {
base base
} }
@ -69,24 +75,141 @@ func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ? and sql like '%`" + colName + "`%'", args fmt.Println(tableName, colName)
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
fmt.Println(sql)
return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) {
/*args := []interface{}{db.dbname, tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
SELECT sql FROM sqlite_master WHERE type='table' and name = 'category'; var sql string
sql := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + for _, record := range res {
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" for name, content := range record {
if name == "sql" {
sql = string(content)
}
}
}
return sql, args*/ nStart := strings.Index(sql, "(")
return nil, ErrNotImplemented nEnd := strings.Index(sql, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",")
cols := make(map[string]*Column)
for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column)
col.Indexes = make(map[string]bool)
for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(field, "`[] ")
continue
} else if idx == 1 {
col.SQLType = SQLType{field, 0, 0}
}
switch field {
case "PRIMARY":
col.IsPrimaryKey = true
case "AUTOINCREMENT":
col.IsAutoIncrement = true
case "NULL":
if fields[idx-1] == "NOT" {
col.Nullable = false
} else {
col.Nullable = true
}
}
}
cols[col.Name] = col
}
return cols, nil
} }
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*Table, error) {
return nil, ErrNotImplemented args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
tables := make([]*Table, 0)
for _, record := range res {
table := new(Table)
for name, content := range record {
switch name {
case "name":
table.Name = string(content)
}
}
if table.Name == "sqlite_sequence" {
continue
}
tables = append(tables, table)
}
return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, ErrNotImplemented args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
indexes := make(map[string]*Index, 0)
for _, record := range res {
var sql string
index := new(Index)
for name, content := range record {
if name == "sql" {
sql = string(content)
}
}
nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
index.Name = indexName[5+len(tableName) : len(indexName)]
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType
} else {
index.Type = IndexType
}
nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")")
colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
index.Cols = make([]string, 0)
for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []"))
}
}
return indexes, nil
} }

View File

@ -48,3 +48,31 @@ func (c *Command) Usage() {
func (c *Command) Runnable() bool { func (c *Command) Runnable() bool {
return c.Run != nil return c.Run != nil
} }
// checkFlags checks if the flag exists with correct format.
func checkFlags(flags map[string]bool, args []string, print func(string)) int {
num := 0 // Number of valid flags, use to cut out.
for i, f := range args {
// Check flag prefix '-'.
if !strings.HasPrefix(f, "-") {
// Not a flag, finish check process.
break
}
// Check if it a valid flag.
if v, ok := flags[f]; ok {
flags[f] = !v
if !v {
print(f)
} else {
fmt.Println("DISABLE: " + f)
}
} else {
fmt.Printf("[ERRO] Unknown flag: %s.\n", f)
return -1
}
num = i + 1
}
return num
}

View File

@ -1,11 +1,18 @@
package main package main
import ( import (
//"github.com/lunny/xorm" "github.com/lunny/xorm"
"strings" "strings"
"xorm"
) )
func unTitle(src string) string {
if src == "" {
return ""
}
return strings.ToLower(string(src[0])) + src[1:]
}
func typestring(st xorm.SQLType) string { func typestring(st xorm.SQLType) string {
t := xorm.SQLType2Type(st) t := xorm.SQLType2Type(st)
s := t.String() s := t.String()
@ -15,7 +22,7 @@ func typestring(st xorm.SQLType) string {
return s return s
} }
func tag(col *xorm.Column) string { func tag(table *xorm.Table, col *xorm.Column) string {
res := make([]string, 0) res := make([]string, 0)
if !col.Nullable { if !col.Nullable {
res = append(res, "not null") res = append(res, "not null")
@ -35,6 +42,19 @@ func tag(col *xorm.Column) string {
if col.IsUpdated { if col.IsUpdated {
res = append(res, "updated") res = append(res, "updated")
} }
for name, _ := range col.Indexes {
index := table.Indexes[name]
var uistr string
if index.Type == xorm.UniqueType {
uistr = "unique"
} else if index.Type == xorm.IndexType {
uistr = "index"
}
if index.Name != col.Name {
uistr += "(" + index.Name + ")"
}
res = append(res, uistr)
}
if len(res) > 0 { if len(res) > 0 {
return "`xorm:\"" + strings.Join(res, " ") + "\"`" return "`xorm:\"" + strings.Join(res, " ") + "\"`"

View File

@ -1,11 +1,12 @@
package main package main
import ( import (
"fmt"
//"github.com/lunny/xorm"
"bytes" "bytes"
"fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
"github.com/dvirsky/go-pylog/logging"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"go/format" "go/format"
@ -14,37 +15,56 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"text/template" "text/template"
"xorm"
) )
var CmdReverse = &Command{ var CmdReverse = &Command{
UsageLine: "reverse -m driverName datasourceName tmplpath", UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]",
Short: "reverse a db to codes", Short: "reverse a db to codes",
Long: ` Long: `
according database's tables and columns to generate codes for Go, C++ and etc. according database's tables and columns to generate codes for Go, C++ and etc.
-m Generated one go file for every table
driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres
datasourceName Database connection uri, for detail infomation please visit driver's project page
tmplPath Template dir for generated. the default templates dir has provide 1 template
generatedPath This parameter is optional, if blank, the default value is model, then will
generated all codes in model dir
`, `,
} }
func init() { func init() {
CmdReverse.Run = runReverse CmdReverse.Run = runReverse
CmdReverse.Flags = map[string]bool{} CmdReverse.Flags = map[string]bool{
"-m": false,
}
} }
func printReversePrompt(flag string) { func printReversePrompt(flag string) {
} }
type Tmpl struct { type Tmpl struct {
Table *xorm.Table Tables []*xorm.Table
Imports map[string]string Imports map[string]string
Model string Model string
} }
func runReverse(cmd *Command, args []string) { func runReverse(cmd *Command, args []string) {
num := checkFlags(cmd.Flags, args, printReversePrompt)
if num == -1 {
return
}
args = args[num:]
if len(args) < 3 { if len(args) < 3 {
fmt.Println("no") fmt.Println("no")
return return
} }
var isMultiFile bool
if _, ok := cmd.Flags["-m"]; ok {
isMultiFile = true
}
curPath, err := os.Getwd() curPath, err := os.Getwd()
if err != nil { if err != nil {
fmt.Println(curPath) fmt.Println(curPath)
@ -69,23 +89,22 @@ func runReverse(cmd *Command, args []string) {
Orm, err := xorm.NewEngine(args[0], args[1]) Orm, err := xorm.NewEngine(args[0], args[1])
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return return
} }
tables, err := Orm.DBMetas() tables, err := Orm.DBMetas()
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return return
} }
dir, err := filepath.Abs(args[2]) dir, err := filepath.Abs(args[2])
if err != nil { if err != nil {
fmt.Println(curPath) logging.Error("%v", err)
return return
} }
var isMultiFile bool = true
m := &xorm.SnakeMapper{} m := &xorm.SnakeMapper{}
filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
@ -95,7 +114,7 @@ func runReverse(cmd *Command, args []string) {
bs, err := ioutil.ReadFile(f) bs, err := ioutil.ReadFile(f)
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return err return err
} }
@ -107,7 +126,7 @@ func runReverse(cmd *Command, args []string) {
tmpl, err := t.Parse(string(bs)) tmpl, err := t.Parse(string(bs))
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return err return err
} }
@ -117,58 +136,85 @@ func runReverse(cmd *Command, args []string) {
ext := path.Ext(newFileName) ext := path.Ext(newFileName)
if !isMultiFile { if !isMultiFile {
w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0700) w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600)
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return err return err
} }
}
for _, table := range tables {
// imports
imports := make(map[string]string) imports := make(map[string]string)
tbls := make([]*xorm.Table, 0)
for _, table := range tables {
for _, col := range table.Columns { for _, col := range table.Columns {
if typestring(col.SQLType) == "time.Time" { if typestring(col.SQLType) == "time.Time" {
imports["time.Time"] = "time.Time" imports["time"] = "time"
} }
} }
tbls = append(tbls, table)
if isMultiFile {
w, err = os.OpenFile(path.Join(genDir, m.Table2Obj(table.Name)+ext), os.O_RDWR|os.O_CREATE, 0700)
if err != nil {
fmt.Println(err)
return err
}
} }
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
t := &Tmpl{Table: table, Imports: imports, Model: model} t := &Tmpl{Tables: tbls, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t) err = tmpl.Execute(newbytes, t)
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return err return err
} }
tplcontent, err := ioutil.ReadAll(newbytes) tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return err return err
} }
source, err := format.Source(tplcontent) source, err := format.Source(tplcontent)
if err != nil { if err != nil {
fmt.Println(err) logging.Error("%v", err)
return err return err
} }
w.WriteString(string(source)) w.WriteString(string(source))
if isMultiFile {
w.Close() w.Close()
} else {
for _, table := range tables {
// imports
imports := make(map[string]string)
for _, col := range table.Columns {
if typestring(col.SQLType) == "time.Time" {
imports["time"] = "time"
} }
} }
if !isMultiFile {
w, err := os.OpenFile(path.Join(genDir, unTitle(m.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600)
if err != nil {
logging.Error("%v", err)
return err
}
newbytes := bytes.NewBufferString("")
t := &Tmpl{Tables: []*xorm.Table{table}, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t)
if err != nil {
logging.Error("%v", err)
return err
}
tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil {
logging.Error("%v", err)
return err
}
source, err := format.Source(tplcontent)
if err != nil {
logging.Error("%v", err)
return err
}
w.WriteString(string(source))
w.Close() w.Close()
} }
}
return nil return nil
}) })

View File

@ -1,11 +1,14 @@
package {{.Model}} package {{.Model}}
import ( import (
"github.com/lunny/xorm"
{{range .Imports}}"{{.}}"{{end}} {{range .Imports}}"{{.}}"{{end}}
) )
type {{Mapper .Table.Name}} struct { {{range .Tables}}
{{range .Table.Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag .}} type {{Mapper .Name}} struct {
{{$table := .}}
{{range .Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag $table .}}
{{end}} {{end}}
} }
{{end}}

View File

@ -2,6 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"github.com/dvirsky/go-pylog/logging"
"io" "io"
"os" "os"
"runtime" "runtime"
@ -28,6 +29,7 @@ func init() {
} }
func main() { func main() {
logging.SetLevel(logging.ALL)
// Check length of arguments. // Check length of arguments.
args := os.Args[1:] args := os.Args[1:]
if len(args) < 1 { if len(args) < 1 {