reverse tool improved

This commit is contained in:
Lunny Xiao 2013-10-14 15:40:24 +08:00
parent e8ed91be2b
commit 3604f1a593
16 changed files with 304 additions and 76 deletions

View File

@ -6,7 +6,7 @@ exit 1
fi fi
CURDIR=`pwd` CURDIR=`pwd`
NEWPATH="$GOPATH/src/${PWD##*/}" NEWPATH="$GOPATH/src/github.com/lunny/${PWD##*/}"
if [ ! -d "$NEWPATH" ]; then if [ ! -d "$NEWPATH" ]; then
ln -s $CURDIR $NEWPATH ln -s $CURDIR $NEWPATH
fi fi

View File

@ -282,7 +282,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
if indexName == "PRIMARY" { if indexName == "PRIMARY" {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
}
var index *Index var index *Index
var ok bool var ok bool

View File

@ -272,7 +272,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
}
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {

View File

@ -438,8 +438,9 @@ func (statement *Statement) convertIdSql(sql string) string {
if len(sqls) != 2 { if len(sqls) != 2 {
return "" return ""
} }
return fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()),
statement.Engine.Quote(col.Name), sqls[1]) statement.Engine.Quote(col.Name), sqls[1])
return newsql
} }
} }
return "" return ""
@ -535,14 +536,14 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
cacher := table.Cacher cacher := table.Cacher
ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args)
if err != nil { if err != nil {
session.Engine.LogError(err) //session.Engine.LogError(err)
resultsSlice, err := session.query(newsql, args...) resultsSlice, err := session.query(newsql, args...)
if err != nil { if err != nil {
return err return err
} }
// 查询数目太大,采用缓存将不是一个很好的方式。 // 查询数目太大,采用缓存将不是一个很好的方式。
if len(resultsSlice) > 100 { if len(resultsSlice) > 100 {
session.Engine.LogDebug("[xorm:cacheFind] ids > 100, no cache") session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 100, no cache", len(resultsSlice))
return ErrCacheFailed return ErrCacheFailed
} }
@ -574,6 +575,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
ididxes := make(map[int64]int)
var idxes []int = make([]int, 0) var idxes []int = make([]int, 0)
var ides []interface{} = make([]interface{}, 0) var ides []interface{} = make([]interface{}, 0)
var temps []interface{} = make([]interface{}, len(ids)) var temps []interface{} = make([]interface{}, len(ids))
@ -583,6 +585,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
if bean == nil { if bean == nil {
idxes = append(idxes, idx) idxes = append(idxes, idx)
ides = append(ides, id) ides = append(ides, id)
ididxes[id] = idx
} else { } else {
session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean)
temps[idx] = bean temps[idx] = bean
@ -597,10 +600,13 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
beans := slices.Interface() beans := slices.Interface()
//beans := reflect.New(sliceValue.Type()).Interface() //beans := reflect.New(sliceValue.Type()).Interface()
err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans)
//err = newSession.In("(id)", ides...).NoCache().Find(beans)
if err != nil { if err != nil {
return err return err
} }
pkFieldName := session.Statement.RefTable.PKColumn().FieldName
vs := reflect.Indirect(reflect.ValueOf(beans)) vs := reflect.Indirect(reflect.ValueOf(beans))
for i := 0; i < vs.Len(); i++ { for i := 0; i < vs.Len(); i++ {
rv := vs.Index(i) rv := vs.Index(i)
@ -608,8 +614,10 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
rv = rv.Addr() rv = rv.Addr()
} }
bean := rv.Interface() bean := rv.Interface()
id := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int()
//bean := vs.Index(i).Addr().Interface() //bean := vs.Index(i).Addr().Interface()
temps[idxes[i]] = bean temps[ididxes[id]] = bean
//temps[idxes[i]] = bean
session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, ides[i], bean) session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, ides[i], bean)
cacher.PutBean(tableName, ides[i].(int64), bean) cacher.PutBean(tableName, ides[i].(int64), bean)
} }
@ -617,7 +625,10 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
for j := 0; j < len(temps); j++ { for j := 0; j < len(temps); j++ {
bean := temps[j] bean := temps[j]
if bean != nil { if bean == nil {
session.Engine.LogError("[xorm:cacheFind] cache error:", tableName, ides[j], bean)
return errors.New("cache error")
}
if sliceValue.Kind() == reflect.Slice { if sliceValue.Kind() == reflect.Slice {
if t.Kind() == reflect.Ptr { if t.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean))) sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean)))
@ -637,13 +648,13 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean)))
} }
} }
} else { /*} else {
session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j]) session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j])
cacher.DelBean(tableName, ids[j]) cacher.DelBean(tableName, ids[j])
session.Engine.LogDebug("[xorm:cacheFind] cache clear:", tableName) session.Engine.LogDebug("[xorm:cacheFind] cache clear:", tableName)
cacher.ClearIds(tableName) cacher.ClearIds(tableName)
} }*/
} }
return nil return nil

View File

@ -111,6 +111,7 @@ func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true
for idx, field := range fields { for idx, field := range fields {
if idx == 0 { if idx == 0 {
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
@ -193,7 +194,12 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else {
index.Name = indexName
}
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType index.Type = UniqueType
@ -209,6 +215,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
for _, col := range colIndexes { for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []")) index.Cols = append(index.Cols, strings.Trim(col, "` []"))
} }
indexes[index.Name] = index
} }
return indexes, nil return indexes, nil

View File

@ -144,7 +144,8 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
} }
func SQLType2Type(st SQLType) reflect.Type { func SQLType2Type(st SQLType) reflect.Type {
switch st.Name { name := strings.ToUpper(st.Name)
switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1) return reflect.TypeOf(1)
case BigInt, BigSerial: case BigInt, BigSerial:

View File

@ -1 +1,65 @@
package main package main
import (
//"fmt"
"github.com/lunny/xorm"
"strings"
"text/template"
)
var (
CPlusTmpl LangTmpl = LangTmpl{
template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": cPlusTypeStr,
"UnTitle": unTitle,
},
nil,
genCPlusImports,
}
)
func cPlusTypeStr(col *xorm.Column) string {
tp := col.SQLType
name := strings.ToUpper(tp.Name)
switch name {
case xorm.Bit, xorm.TinyInt, xorm.SmallInt, xorm.MediumInt, xorm.Int, xorm.Integer, xorm.Serial:
return "int"
case xorm.BigInt, xorm.BigSerial:
return "__int64"
case xorm.Char, xorm.Varchar, xorm.TinyText, xorm.Text, xorm.MediumText, xorm.LongText:
return "tstring"
case xorm.Date, xorm.DateTime, xorm.Time, xorm.TimeStamp:
return "time_t"
case xorm.Decimal, xorm.Numeric:
return "tstring"
case xorm.Real, xorm.Float:
return "float"
case xorm.Double:
return "double"
case xorm.TinyBlob, xorm.Blob, xorm.MediumBlob, xorm.LongBlob, xorm.Bytea:
return "tstring"
case xorm.Bool:
return "bool"
default:
return "tstring"
}
return ""
}
func genCPlusImports(tables []*xorm.Table) map[string]string {
imports := make(map[string]string)
for _, table := range tables {
for _, col := range table.Columns {
switch cPlusTypeStr(col) {
case "time_t":
imports[`<time.h>`] = `<time.h>`
case "tstring":
imports["<string>"] = "<string>"
//case "__int64":
// imports[""] = ""
}
}
}
return imports
}

View File

@ -2,18 +2,49 @@ package main
import ( import (
"github.com/lunny/xorm" "github.com/lunny/xorm"
"go/format"
"strings" "strings"
"text/template"
) )
func unTitle(src string) string { var (
if src == "" { GoLangTmpl LangTmpl = LangTmpl{
return "" template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": typestring,
"Tag": tag,
"UnTitle": unTitle,
},
formatGo,
genGoImports,
} }
)
return strings.ToLower(string(src[0])) + src[1:] func formatGo(src string) (string, error) {
source, err := format.Source([]byte(src))
if err != nil {
return "", err
}
return string(source), nil
} }
func typestring(st xorm.SQLType) string { func genGoImports(tables []*xorm.Table) map[string]string {
imports := make(map[string]string)
for _, table := range tables {
for _, col := range table.Columns {
if typestring(col) == "time.Time" {
imports["time"] = "time"
}
}
}
return imports
}
func typestring(col *xorm.Column) string {
st := col.SQLType
if col.IsPrimaryKey {
return "int64"
}
t := xorm.SQLType2Type(st) t := xorm.SQLType2Type(st)
s := t.String() s := t.String()
if s == "[]uint8" { if s == "[]uint8" {
@ -23,19 +54,26 @@ func typestring(st xorm.SQLType) string {
} }
func tag(table *xorm.Table, col *xorm.Column) string { func tag(table *xorm.Table, col *xorm.Column) string {
isNameId := (mapper.Table2Obj(col.Name) == "Id")
res := make([]string, 0) res := make([]string, 0)
if !col.Nullable { if !col.Nullable {
if !isNameId {
res = append(res, "not null") res = append(res, "not null")
} }
}
if col.IsPrimaryKey { if col.IsPrimaryKey {
if !isNameId {
res = append(res, "pk") res = append(res, "pk")
} }
}
if col.Default != "" { if col.Default != "" {
res = append(res, "default "+col.Default) res = append(res, "default "+col.Default)
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
if !isNameId {
res = append(res, "autoincr") res = append(res, "autoincr")
} }
}
if col.IsCreated { if col.IsCreated {
res = append(res, "created") res = append(res, "created")
} }

47
xorm/lang.go Normal file
View File

@ -0,0 +1,47 @@
package main
import (
"github.com/lunny/xorm"
"io/ioutil"
"strings"
"text/template"
)
type LangTmpl struct {
Funcs template.FuncMap
Formater func(string) (string, error)
GenImports func([]*xorm.Table) map[string]string
}
var (
mapper = &xorm.SnakeMapper{}
langTmpls = map[string]LangTmpl{
"go": GoLangTmpl,
"c++": CPlusTmpl,
}
)
func loadConfig(f string) map[string]string {
bts, err := ioutil.ReadFile(f)
if err != nil {
return nil
}
configs := make(map[string]string)
lines := strings.Split(string(bts), "\n")
for _, line := range lines {
line = strings.TrimRight(line, "\r")
vs := strings.Split(line, "=")
if len(vs) == 2 {
configs[strings.TrimSpace(vs[0])] = strings.TrimSpace(vs[1])
}
}
return configs
}
func unTitle(src string) string {
if src == "" {
return ""
}
return strings.ToLower(string(src[0])) + src[1:]
}

View File

@ -9,7 +9,6 @@ import (
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"go/format"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
@ -36,6 +35,7 @@ func init() {
CmdReverse.Run = runReverse CmdReverse.Run = runReverse
CmdReverse.Flags = map[string]bool{ CmdReverse.Flags = map[string]bool{
"-m": false, "-m": false,
"-l": false,
} }
} }
@ -85,6 +85,29 @@ func runReverse(cmd *Command, args []string) {
genDir = path.Join(curPath, model) genDir = path.Join(curPath, model)
} }
dir, err := filepath.Abs(args[2])
if err != nil {
logging.Error("%v", err)
return
}
var langTmpl LangTmpl
var ok bool
var lang string = "go"
cfgPath := path.Join(dir, "config")
info, err := os.Stat(cfgPath)
var configs map[string]string
if err == nil && !info.IsDir() {
configs = loadConfig(cfgPath)
lang = configs["lang"]
}
if langTmpl, ok = langTmpls[lang]; !ok {
fmt.Println("Unsupported lang", lang)
return
}
os.MkdirAll(genDir, os.ModePerm) os.MkdirAll(genDir, os.ModePerm)
Orm, err := xorm.NewEngine(args[0], args[1]) Orm, err := xorm.NewEngine(args[0], args[1])
@ -99,19 +122,15 @@ func runReverse(cmd *Command, args []string) {
return return
} }
dir, err := filepath.Abs(args[2])
if err != nil {
logging.Error("%v", err)
return
}
m := &xorm.SnakeMapper{}
filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
if info.IsDir() { if info.IsDir() {
return nil return nil
} }
if info.Name() == "config" {
return nil
}
bs, err := ioutil.ReadFile(f) bs, err := ioutil.ReadFile(f)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
@ -119,10 +138,7 @@ func runReverse(cmd *Command, args []string) {
} }
t := template.New(f) t := template.New(f)
t.Funcs(template.FuncMap{"Mapper": m.Table2Obj, t.Funcs(langTmpl.Funcs)
"Type": typestring,
"Tag": tag,
})
tmpl, err := t.Parse(string(bs)) tmpl, err := t.Parse(string(bs))
if err != nil { if err != nil {
@ -142,14 +158,10 @@ func runReverse(cmd *Command, args []string) {
return err return err
} }
imports := make(map[string]string) imports := langTmpl.GenImports(tables)
tbls := make([]*xorm.Table, 0) tbls := make([]*xorm.Table, 0)
for _, table := range tables { for _, table := range tables {
for _, col := range table.Columns {
if typestring(col.SQLType) == "time.Time" {
imports["time"] = "time"
}
}
tbls = append(tbls, table) tbls = append(tbls, table)
} }
@ -167,25 +179,26 @@ func runReverse(cmd *Command, args []string) {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
source, err := format.Source(tplcontent) var source string
if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent))
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
} else {
source = string(tplcontent)
}
w.WriteString(string(source)) w.WriteString(source)
w.Close() w.Close()
} else { } else {
for _, table := range tables { for _, table := range tables {
// imports // imports
imports := make(map[string]string) tbs := []*xorm.Table{table}
for _, col := range table.Columns { imports := langTmpl.GenImports(tbs)
if typestring(col.SQLType) == "time.Time" {
imports["time"] = "time"
}
}
w, err := os.OpenFile(path.Join(genDir, unTitle(m.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
@ -193,7 +206,7 @@ func runReverse(cmd *Command, args []string) {
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
t := &Tmpl{Tables: []*xorm.Table{table}, Imports: imports, Model: model} t := &Tmpl{Tables: tbs, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t) err = tmpl.Execute(newbytes, t)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
@ -205,13 +218,18 @@ func runReverse(cmd *Command, args []string) {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
source, err := format.Source(tplcontent) var source string
if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent))
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
} else {
source = string(tplcontent)
}
w.WriteString(string(source)) w.WriteString(source)
w.Close() w.Close()
} }
} }

View File

@ -0,0 +1,21 @@
{{ range .Imports}}
#include {{.}}
{{ end }}
{{range .Tables}}class {{Mapper .Name}} {
{{$table := .}}
public:
{{range .Columns}}{{$name := Mapper .Name}} {{Type .}} Get{{Mapper .Name}}() {
return this->m_{{UnTitle $name}};
}
void Set{{$name}}({{Type .}} {{UnTitle $name}}) {
this->m_{{UnTitle $name}} = {{UnTitle $name}};
}
{{end}}private:
{{range .Columns}}{{$name := Mapper .Name}} {{Type .}} m_{{UnTitle $name}};
{{end}}
}
{{end}}

View File

@ -0,0 +1 @@
lang=c++

1
xorm/templates/go/config Normal file
View File

@ -0,0 +1 @@
lang=go

View File

@ -7,7 +7,7 @@ import (
{{range .Tables}} {{range .Tables}}
type {{Mapper .Name}} struct { type {{Mapper .Name}} struct {
{{$table := .}} {{$table := .}}
{{range .Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag $table .}} {{range .Columns}} {{Mapper .Name}} {{Type .}}
{{end}} {{end}}
} }

View File

@ -0,0 +1 @@
lang=go

View File

@ -0,0 +1,14 @@
package {{.Model}}
import (
{{range .Imports}}"{{.}}"{{end}}
)
{{range .Tables}}
type {{Mapper .Name}} struct {
{{$table := .}}
{{range .Columns}} {{Mapper .Name}} {{Type .}} {{Tag $table .}}
{{end}}
}
{{end}}