diff --git a/install b/install index 1ccc8fc6..53326e0b 100755 --- a/install +++ b/install @@ -6,7 +6,7 @@ exit 1 fi CURDIR=`pwd` -NEWPATH="$GOPATH/src/${PWD##*/}" +NEWPATH="$GOPATH/src/github.com/lunny/${PWD##*/}" if [ ! -d "$NEWPATH" ]; then ln -s $CURDIR $NEWPATH fi diff --git a/mysql.go b/mysql.go index a75179e9..4d436616 100644 --- a/mysql.go +++ b/mysql.go @@ -282,7 +282,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { if indexName == "PRIMARY" { continue } - indexName = indexName[5+len(tableName) : len(indexName)] + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { + indexName = indexName[5+len(tableName) : len(indexName)] + } var index *Index var ok bool diff --git a/postgres.go b/postgres.go index af873ff8..4cbc84aa 100644 --- a/postgres.go +++ b/postgres.go @@ -272,7 +272,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { if strings.HasSuffix(indexName, "_pkey") { continue } - indexName = indexName[5+len(tableName) : len(indexName)] + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { + indexName = indexName[5+len(tableName) : len(indexName)] + } index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} for _, colName := range colNames { diff --git a/session.go b/session.go index 9352ad5a..21cc539e 100644 --- a/session.go +++ b/session.go @@ -438,8 +438,9 @@ func (statement *Statement) convertIdSql(sql string) string { if len(sqls) != 2 { return "" } - return fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), + newsql := fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(col.Name), sqls[1]) + return newsql } } return "" @@ -535,14 +536,14 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter cacher := table.Cacher ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) if err != nil { - session.Engine.LogError(err) + //session.Engine.LogError(err) resultsSlice, err := session.query(newsql, args...) if err != nil { return err } // 查询数目太大,采用缓存将不是一个很好的方式。 if len(resultsSlice) > 100 { - session.Engine.LogDebug("[xorm:cacheFind] ids > 100, no cache") + session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 100, no cache", len(resultsSlice)) return ErrCacheFailed } @@ -574,6 +575,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + ididxes := make(map[int64]int) var idxes []int = make([]int, 0) var ides []interface{} = make([]interface{}, 0) var temps []interface{} = make([]interface{}, len(ids)) @@ -583,6 +585,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter if bean == nil { idxes = append(idxes, idx) ides = append(ides, id) + ididxes[id] = idx } else { session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) temps[idx] = bean @@ -597,10 +600,13 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter beans := slices.Interface() //beans := reflect.New(sliceValue.Type()).Interface() err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) + //err = newSession.In("(id)", ides...).NoCache().Find(beans) if err != nil { return err } + pkFieldName := session.Statement.RefTable.PKColumn().FieldName + vs := reflect.Indirect(reflect.ValueOf(beans)) for i := 0; i < vs.Len(); i++ { rv := vs.Index(i) @@ -608,8 +614,10 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter rv = rv.Addr() } bean := rv.Interface() + id := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() //bean := vs.Index(i).Addr().Interface() - temps[idxes[i]] = bean + temps[ididxes[id]] = bean + //temps[idxes[i]] = bean session.Engine.LogDebug("[xorm:cacheFind] cache bean:", tableName, ides[i], bean) cacher.PutBean(tableName, ides[i].(int64), bean) } @@ -617,33 +625,36 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter for j := 0; j < len(temps); j++ { bean := temps[j] - if bean != nil { - if sliceValue.Kind() == reflect.Slice { - if t.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean))) - } else { - sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) - } - } else if sliceValue.Kind() == reflect.Map { - var key int64 - if table.PrimaryKey != "" { - key = ids[j] - } else { - key = int64(j) - } - if t.Kind() == reflect.Ptr { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean)) - } else { - sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) - } + if bean == nil { + session.Engine.LogError("[xorm:cacheFind] cache error:", tableName, ides[j], bean) + return errors.New("cache error") + } + if sliceValue.Kind() == reflect.Slice { + if t.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean))) + } else { + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) } - } else { + } else if sliceValue.Kind() == reflect.Map { + var key int64 + if table.PrimaryKey != "" { + key = ids[j] + } else { + key = int64(j) + } + if t.Kind() == reflect.Ptr { + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean)) + } else { + sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean))) + } + } + /*} else { session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j]) cacher.DelBean(tableName, ids[j]) session.Engine.LogDebug("[xorm:cacheFind] cache clear:", tableName) cacher.ClearIds(tableName) - } + }*/ } return nil diff --git a/sqlite3.go b/sqlite3.go index 26724695..ee2663ac 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -111,6 +111,7 @@ func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { fields := strings.Fields(strings.TrimSpace(colStr)) col := new(Column) col.Indexes = make(map[string]bool) + col.Nullable = true for idx, field := range fields { if idx == 0 { col.Name = strings.Trim(field, "`[] ") @@ -193,7 +194,12 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { nNStart := strings.Index(sql, "INDEX") nNEnd := strings.Index(sql, "ON") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") - index.Name = indexName[5+len(tableName) : len(indexName)] + //fmt.Println(indexName) + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { + index.Name = indexName[5+len(tableName) : len(indexName)] + } else { + index.Name = indexName + } if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { index.Type = UniqueType @@ -209,6 +215,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { for _, col := range colIndexes { index.Cols = append(index.Cols, strings.Trim(col, "` []")) } + indexes[index.Name] = index } return indexes, nil diff --git a/table.go b/table.go index 294cb438..b7303808 100644 --- a/table.go +++ b/table.go @@ -144,7 +144,8 @@ func Type2SQLType(t reflect.Type) (st SQLType) { } func SQLType2Type(st SQLType) reflect.Type { - switch st.Name { + name := strings.ToUpper(st.Name) + switch name { case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: return reflect.TypeOf(1) case BigInt, BigSerial: diff --git a/xorm/c++.go b/xorm/c++.go index 06ab7d0f..1a759d01 100644 --- a/xorm/c++.go +++ b/xorm/c++.go @@ -1 +1,65 @@ package main + +import ( + //"fmt" + "github.com/lunny/xorm" + "strings" + "text/template" +) + +var ( + CPlusTmpl LangTmpl = LangTmpl{ + template.FuncMap{"Mapper": mapper.Table2Obj, + "Type": cPlusTypeStr, + "UnTitle": unTitle, + }, + nil, + genCPlusImports, + } +) + +func cPlusTypeStr(col *xorm.Column) string { + tp := col.SQLType + name := strings.ToUpper(tp.Name) + switch name { + case xorm.Bit, xorm.TinyInt, xorm.SmallInt, xorm.MediumInt, xorm.Int, xorm.Integer, xorm.Serial: + return "int" + case xorm.BigInt, xorm.BigSerial: + return "__int64" + case xorm.Char, xorm.Varchar, xorm.TinyText, xorm.Text, xorm.MediumText, xorm.LongText: + return "tstring" + case xorm.Date, xorm.DateTime, xorm.Time, xorm.TimeStamp: + return "time_t" + case xorm.Decimal, xorm.Numeric: + return "tstring" + case xorm.Real, xorm.Float: + return "float" + case xorm.Double: + return "double" + case xorm.TinyBlob, xorm.Blob, xorm.MediumBlob, xorm.LongBlob, xorm.Bytea: + return "tstring" + case xorm.Bool: + return "bool" + default: + return "tstring" + } + return "" +} + +func genCPlusImports(tables []*xorm.Table) map[string]string { + imports := make(map[string]string) + + for _, table := range tables { + for _, col := range table.Columns { + switch cPlusTypeStr(col) { + case "time_t": + imports[``] = `` + case "tstring": + imports[""] = "" + //case "__int64": + // imports[""] = "" + } + } + } + return imports +} diff --git a/xorm/go.go b/xorm/go.go index 866411dd..74e2acb5 100644 --- a/xorm/go.go +++ b/xorm/go.go @@ -2,18 +2,49 @@ package main import ( "github.com/lunny/xorm" + "go/format" "strings" + "text/template" ) -func unTitle(src string) string { - if src == "" { - return "" +var ( + GoLangTmpl LangTmpl = LangTmpl{ + template.FuncMap{"Mapper": mapper.Table2Obj, + "Type": typestring, + "Tag": tag, + "UnTitle": unTitle, + }, + formatGo, + genGoImports, } +) - return strings.ToLower(string(src[0])) + src[1:] +func formatGo(src string) (string, error) { + source, err := format.Source([]byte(src)) + if err != nil { + return "", err + } + return string(source), nil } -func typestring(st xorm.SQLType) string { +func genGoImports(tables []*xorm.Table) map[string]string { + imports := make(map[string]string) + + for _, table := range tables { + for _, col := range table.Columns { + if typestring(col) == "time.Time" { + imports["time"] = "time" + } + } + } + return imports +} + +func typestring(col *xorm.Column) string { + st := col.SQLType + if col.IsPrimaryKey { + return "int64" + } t := xorm.SQLType2Type(st) s := t.String() if s == "[]uint8" { @@ -23,18 +54,25 @@ func typestring(st xorm.SQLType) string { } func tag(table *xorm.Table, col *xorm.Column) string { + isNameId := (mapper.Table2Obj(col.Name) == "Id") res := make([]string, 0) if !col.Nullable { - res = append(res, "not null") + if !isNameId { + res = append(res, "not null") + } } if col.IsPrimaryKey { - res = append(res, "pk") + if !isNameId { + res = append(res, "pk") + } } if col.Default != "" { res = append(res, "default "+col.Default) } if col.IsAutoIncrement { - res = append(res, "autoincr") + if !isNameId { + res = append(res, "autoincr") + } } if col.IsCreated { res = append(res, "created") diff --git a/xorm/lang.go b/xorm/lang.go new file mode 100644 index 00000000..d7f02d20 --- /dev/null +++ b/xorm/lang.go @@ -0,0 +1,47 @@ +package main + +import ( + "github.com/lunny/xorm" + "io/ioutil" + "strings" + "text/template" +) + +type LangTmpl struct { + Funcs template.FuncMap + Formater func(string) (string, error) + GenImports func([]*xorm.Table) map[string]string +} + +var ( + mapper = &xorm.SnakeMapper{} + langTmpls = map[string]LangTmpl{ + "go": GoLangTmpl, + "c++": CPlusTmpl, + } +) + +func loadConfig(f string) map[string]string { + bts, err := ioutil.ReadFile(f) + if err != nil { + return nil + } + configs := make(map[string]string) + lines := strings.Split(string(bts), "\n") + for _, line := range lines { + line = strings.TrimRight(line, "\r") + vs := strings.Split(line, "=") + if len(vs) == 2 { + configs[strings.TrimSpace(vs[0])] = strings.TrimSpace(vs[1]) + } + } + return configs +} + +func unTitle(src string) string { + if src == "" { + return "" + } + + return strings.ToLower(string(src[0])) + src[1:] +} diff --git a/xorm/reverse.go b/xorm/reverse.go index 2a340ee3..8725143c 100644 --- a/xorm/reverse.go +++ b/xorm/reverse.go @@ -9,7 +9,6 @@ import ( "github.com/lunny/xorm" _ "github.com/mattn/go-sqlite3" _ "github.com/ziutek/mymysql/godrv" - "go/format" "io/ioutil" "os" "path" @@ -36,6 +35,7 @@ func init() { CmdReverse.Run = runReverse CmdReverse.Flags = map[string]bool{ "-m": false, + "-l": false, } } @@ -85,6 +85,29 @@ func runReverse(cmd *Command, args []string) { genDir = path.Join(curPath, model) } + dir, err := filepath.Abs(args[2]) + if err != nil { + logging.Error("%v", err) + return + } + + var langTmpl LangTmpl + var ok bool + var lang string = "go" + + cfgPath := path.Join(dir, "config") + info, err := os.Stat(cfgPath) + var configs map[string]string + if err == nil && !info.IsDir() { + configs = loadConfig(cfgPath) + lang = configs["lang"] + } + + if langTmpl, ok = langTmpls[lang]; !ok { + fmt.Println("Unsupported lang", lang) + return + } + os.MkdirAll(genDir, os.ModePerm) Orm, err := xorm.NewEngine(args[0], args[1]) @@ -99,19 +122,15 @@ func runReverse(cmd *Command, args []string) { return } - dir, err := filepath.Abs(args[2]) - if err != nil { - logging.Error("%v", err) - return - } - - m := &xorm.SnakeMapper{} - filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { if info.IsDir() { return nil } + if info.Name() == "config" { + return nil + } + bs, err := ioutil.ReadFile(f) if err != nil { logging.Error("%v", err) @@ -119,10 +138,7 @@ func runReverse(cmd *Command, args []string) { } t := template.New(f) - t.Funcs(template.FuncMap{"Mapper": m.Table2Obj, - "Type": typestring, - "Tag": tag, - }) + t.Funcs(langTmpl.Funcs) tmpl, err := t.Parse(string(bs)) if err != nil { @@ -142,14 +158,10 @@ func runReverse(cmd *Command, args []string) { return err } - imports := make(map[string]string) + imports := langTmpl.GenImports(tables) + tbls := make([]*xorm.Table, 0) for _, table := range tables { - for _, col := range table.Columns { - if typestring(col.SQLType) == "time.Time" { - imports["time"] = "time" - } - } tbls = append(tbls, table) } @@ -167,25 +179,26 @@ func runReverse(cmd *Command, args []string) { logging.Error("%v", err) return err } - source, err := format.Source(tplcontent) - if err != nil { - logging.Error("%v", err) - return err + var source string + if langTmpl.Formater != nil { + source, err = langTmpl.Formater(string(tplcontent)) + if err != nil { + logging.Error("%v", err) + return err + } + } else { + source = string(tplcontent) } - w.WriteString(string(source)) + w.WriteString(source) w.Close() } else { for _, table := range tables { // imports - imports := make(map[string]string) - for _, col := range table.Columns { - if typestring(col.SQLType) == "time.Time" { - imports["time"] = "time" - } - } + tbs := []*xorm.Table{table} + imports := langTmpl.GenImports(tbs) - w, err := os.OpenFile(path.Join(genDir, unTitle(m.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) + w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) if err != nil { logging.Error("%v", err) return err @@ -193,7 +206,7 @@ func runReverse(cmd *Command, args []string) { newbytes := bytes.NewBufferString("") - t := &Tmpl{Tables: []*xorm.Table{table}, Imports: imports, Model: model} + t := &Tmpl{Tables: tbs, Imports: imports, Model: model} err = tmpl.Execute(newbytes, t) if err != nil { logging.Error("%v", err) @@ -205,13 +218,18 @@ func runReverse(cmd *Command, args []string) { logging.Error("%v", err) return err } - source, err := format.Source(tplcontent) - if err != nil { - logging.Error("%v", err) - return err + var source string + if langTmpl.Formater != nil { + source, err = langTmpl.Formater(string(tplcontent)) + if err != nil { + logging.Error("%v", err) + return err + } + } else { + source = string(tplcontent) } - w.WriteString(string(source)) + w.WriteString(source) w.Close() } } diff --git a/xorm/templates/c++/class.h.tpl b/xorm/templates/c++/class.h.tpl new file mode 100644 index 00000000..50a32ff8 --- /dev/null +++ b/xorm/templates/c++/class.h.tpl @@ -0,0 +1,21 @@ +{{ range .Imports}} +#include {{.}} +{{ end }} + +{{range .Tables}}class {{Mapper .Name}} { +{{$table := .}} +public: +{{range .Columns}}{{$name := Mapper .Name}} {{Type .}} Get{{Mapper .Name}}() { + return this->m_{{UnTitle $name}}; + } + + void Set{{$name}}({{Type .}} {{UnTitle $name}}) { + this->m_{{UnTitle $name}} = {{UnTitle $name}}; + } + +{{end}}private: +{{range .Columns}}{{$name := Mapper .Name}} {{Type .}} m_{{UnTitle $name}}; +{{end}} +} + +{{end}} \ No newline at end of file diff --git a/xorm/templates/c++/config b/xorm/templates/c++/config new file mode 100644 index 00000000..4965bae3 --- /dev/null +++ b/xorm/templates/c++/config @@ -0,0 +1 @@ +lang=c++ \ No newline at end of file diff --git a/xorm/templates/go/config b/xorm/templates/go/config new file mode 100644 index 00000000..6fdeea2b --- /dev/null +++ b/xorm/templates/go/config @@ -0,0 +1 @@ +lang=go \ No newline at end of file diff --git a/xorm/templates/go/struct.go.tpl b/xorm/templates/go/struct.go.tpl index 88955b3e..8e59d688 100644 --- a/xorm/templates/go/struct.go.tpl +++ b/xorm/templates/go/struct.go.tpl @@ -7,7 +7,7 @@ import ( {{range .Tables}} type {{Mapper .Name}} struct { {{$table := .}} -{{range .Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag $table .}} +{{range .Columns}} {{Mapper .Name}} {{Type .}} {{end}} } diff --git a/xorm/templates/goxorm/config b/xorm/templates/goxorm/config new file mode 100644 index 00000000..6fdeea2b --- /dev/null +++ b/xorm/templates/goxorm/config @@ -0,0 +1 @@ +lang=go \ No newline at end of file diff --git a/xorm/templates/goxorm/struct.go.tpl b/xorm/templates/goxorm/struct.go.tpl new file mode 100644 index 00000000..71875f7e --- /dev/null +++ b/xorm/templates/goxorm/struct.go.tpl @@ -0,0 +1,14 @@ +package {{.Model}} + +import ( + {{range .Imports}}"{{.}}"{{end}} +) + +{{range .Tables}} +type {{Mapper .Name}} struct { +{{$table := .}} +{{range .Columns}} {{Mapper .Name}} {{Type .}} {{Tag $table .}} +{{end}} +} + +{{end}} \ No newline at end of file