diff --git a/engine.go b/engine.go index 4298dbf9..e9e2d683 100644 --- a/engine.go +++ b/engine.go @@ -200,6 +200,16 @@ func (engine *Engine) DBMetas() ([]*Table, error) { return nil, err } table.Indexes = indexes + + for _, index := range indexes { + for _, name := range index.Cols { + if col, ok := table.Columns[name]; ok { + col.Indexes[index.Name] = true + } else { + return nil, errors.New("Unkonwn col " + name + " in indexes") + } + } + } } return tables, nil } diff --git a/mysql.go b/mysql.go index 8f6effc9..a75179e9 100644 --- a/mysql.go +++ b/mysql.go @@ -155,6 +155,7 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { if err != nil { return nil, err } + defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { return nil, err @@ -162,10 +163,11 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { cols := make(map[string]*Column) for _, record := range res { col := new(Column) + col.Indexes = make(map[string]bool) for name, content := range record { switch name { case "COLUMN_NAME": - col.Name = string(content) + col.Name = strings.Trim(string(content), "` ") case "IS_NULLABLE": if "YES" == string(content) { col.Nullable = true @@ -225,6 +227,7 @@ func (db *mysql) GetTables() ([]*Table, error) { if err != nil { return nil, err } + defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { return nil, err @@ -236,7 +239,7 @@ func (db *mysql) GetTables() ([]*Table, error) { for name, content := range record { switch name { case "TABLE_NAME": - table.Name = string(content) + table.Name = strings.Trim(string(content), "` ") case "ENGINE": } } @@ -252,6 +255,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { if err != nil { return nil, err } + defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { return nil, err @@ -272,7 +276,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { case "INDEX_NAME": indexName = string(content) case "COLUMN_NAME": - colName = string(content) + colName = strings.Trim(string(content), "` ") } } if indexName == "PRIMARY" { diff --git a/postgres.go b/postgres.go index ce612810..af873ff8 100644 --- a/postgres.go +++ b/postgres.go @@ -1,6 +1,7 @@ package xorm import ( + "database/sql" "errors" "fmt" "strconv" @@ -141,13 +142,14 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { args := []interface{}{tableName} - s := "SELECT COLUMN_NAME, column_default, is_nullable, data_type, character_maximum_length" + - " FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + + ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" cnn, err := sql.Open(db.drivername, db.dataSourceName) if err != nil { return nil, err } + defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { return nil, err @@ -155,26 +157,128 @@ func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { cols := make(map[string]*Column) for _, record := range res { col := new(Column) - + col.Indexes = make(map[string]bool) for name, content := range record { switch name { - case "COLUMN_NAME": - col.Name = string(content) + case "column_name": + col.Name = strings.Trim(string(content), `" `) case "column_default": - if strings.HasPrefix(string(content), "") { - col.IsPrimaryKey + if strings.HasPrefix(string(content), "nextval") { + col.IsPrimaryKey = true } + case "is_nullable": + if string(content) == "YES" { + col.Nullable = true + } else { + col.Nullable = false + } + case "data_type": + ct := string(content) + switch ct { + case "character varying", "character": + col.SQLType = SQLType{Varchar, 0, 0} + case "timestamp without time zone": + col.SQLType = SQLType{DateTime, 0, 0} + case "double precision": + col.SQLType = SQLType{Double, 0, 0} + case "boolean": + col.SQLType = SQLType{Bool, 0, 0} + case "time without time zone": + col.SQLType = SQLType{Time, 0, 0} + default: + col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} + } + if _, ok := sqlTypes[col.SQLType.Name]; !ok { + return nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) + } + case "character_maximum_length": + i, err := strconv.Atoi(string(content)) + if err != nil { + return nil, errors.New("retrieve length error") + } + col.Length = i + case "numeric_precision": + case "numeric_precision_radix": } } + cols[col.Name] = col } - return nil, ErrNotImplemented + return cols, nil } func (db *postgres) GetTables() ([]*Table, error) { - return nil, ErrNotImplemented + args := []interface{}{} + s := "SELECT tablename FROM pg_tables where schemaname = 'public'" + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "tablename": + table.Name = string(content) + } + } + tables = append(tables, table) + } + return tables, nil } func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { - return nil, ErrNotImplemented + args := []interface{}{tableName} + s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" + + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName string + var colNames []string + + for name, content := range record { + switch name { + case "indexname": + indexName = strings.Trim(string(content), `" `) + case "indexdef": + c := string(content) + if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { + indexType = UniqueType + } else { + indexType = IndexType + } + cs := strings.Split(c, "(") + colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") + } + } + if strings.HasSuffix(indexName, "_pkey") { + continue + } + indexName = indexName[5+len(tableName) : len(indexName)] + + index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + for _, colName := range colNames { + index.Cols = append(index.Cols, strings.Trim(colName, `" `)) + } + indexes[index.Name] = index + } + return indexes, nil } diff --git a/sqlite3.go b/sqlite3.go index b0b6924e..26724695 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1,5 +1,11 @@ package xorm +import ( + "database/sql" + "fmt" + "strings" +) + type sqlite3 struct { base } @@ -69,24 +75,141 @@ func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ? and sql like '%`" + colName + "`%'", args + fmt.Println(tableName, colName) + sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" + fmt.Println(sql) + return sql, args } func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { - /*args := []interface{}{db.dbname, tableName} + args := []interface{}{tableName} + s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - SELECT sql FROM sqlite_master WHERE type='table' and name = 'category'; - sql := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + - " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + var sql string + for _, record := range res { + for name, content := range record { + if name == "sql" { + sql = string(content) + } + } + } - return sql, args*/ - return nil, ErrNotImplemented + nStart := strings.Index(sql, "(") + nEnd := strings.Index(sql, ")") + colCreates := strings.Split(sql[nStart+1:nEnd], ",") + cols := make(map[string]*Column) + for _, colStr := range colCreates { + fields := strings.Fields(strings.TrimSpace(colStr)) + col := new(Column) + col.Indexes = make(map[string]bool) + for idx, field := range fields { + if idx == 0 { + col.Name = strings.Trim(field, "`[] ") + continue + } else if idx == 1 { + col.SQLType = SQLType{field, 0, 0} + } + switch field { + case "PRIMARY": + col.IsPrimaryKey = true + case "AUTOINCREMENT": + col.IsAutoIncrement = true + case "NULL": + if fields[idx-1] == "NOT" { + col.Nullable = false + } else { + col.Nullable = true + } + } + } + cols[col.Name] = col + } + return cols, nil } func (db *sqlite3) GetTables() ([]*Table, error) { - return nil, ErrNotImplemented + args := []interface{}{} + s := "SELECT name FROM sqlite_master WHERE type='table'" + + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "name": + table.Name = string(content) + } + } + if table.Name == "sqlite_sequence" { + continue + } + tables = append(tables, table) + } + return tables, nil } func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { - return nil, ErrNotImplemented + args := []interface{}{tableName} + s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + indexes := make(map[string]*Index, 0) + for _, record := range res { + var sql string + index := new(Index) + for name, content := range record { + if name == "sql" { + sql = string(content) + } + } + + nNStart := strings.Index(sql, "INDEX") + nNEnd := strings.Index(sql, "ON") + indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") + index.Name = indexName[5+len(tableName) : len(indexName)] + + if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { + index.Type = UniqueType + } else { + index.Type = IndexType + } + + nStart := strings.Index(sql, "(") + nEnd := strings.Index(sql, ")") + colIndexes := strings.Split(sql[nStart+1:nEnd], ",") + + index.Cols = make([]string, 0) + for _, col := range colIndexes { + index.Cols = append(index.Cols, strings.Trim(col, "` []")) + } + } + + return indexes, nil } diff --git a/xorm/cmd.go b/xorm/cmd.go index 527a1836..cb326f15 100644 --- a/xorm/cmd.go +++ b/xorm/cmd.go @@ -48,3 +48,31 @@ func (c *Command) Usage() { func (c *Command) Runnable() bool { return c.Run != nil } + +// checkFlags checks if the flag exists with correct format. +func checkFlags(flags map[string]bool, args []string, print func(string)) int { + num := 0 // Number of valid flags, use to cut out. + for i, f := range args { + // Check flag prefix '-'. + if !strings.HasPrefix(f, "-") { + // Not a flag, finish check process. + break + } + + // Check if it a valid flag. + if v, ok := flags[f]; ok { + flags[f] = !v + if !v { + print(f) + } else { + fmt.Println("DISABLE: " + f) + } + } else { + fmt.Printf("[ERRO] Unknown flag: %s.\n", f) + return -1 + } + num = i + 1 + } + + return num +} diff --git a/xorm/go.go b/xorm/go.go index 19bbd09c..866411dd 100644 --- a/xorm/go.go +++ b/xorm/go.go @@ -1,11 +1,18 @@ package main import ( - //"github.com/lunny/xorm" + "github.com/lunny/xorm" "strings" - "xorm" ) +func unTitle(src string) string { + if src == "" { + return "" + } + + return strings.ToLower(string(src[0])) + src[1:] +} + func typestring(st xorm.SQLType) string { t := xorm.SQLType2Type(st) s := t.String() @@ -15,7 +22,7 @@ func typestring(st xorm.SQLType) string { return s } -func tag(col *xorm.Column) string { +func tag(table *xorm.Table, col *xorm.Column) string { res := make([]string, 0) if !col.Nullable { res = append(res, "not null") @@ -35,6 +42,19 @@ func tag(col *xorm.Column) string { if col.IsUpdated { res = append(res, "updated") } + for name, _ := range col.Indexes { + index := table.Indexes[name] + var uistr string + if index.Type == xorm.UniqueType { + uistr = "unique" + } else if index.Type == xorm.IndexType { + uistr = "index" + } + if index.Name != col.Name { + uistr += "(" + index.Name + ")" + } + res = append(res, uistr) + } if len(res) > 0 { return "`xorm:\"" + strings.Join(res, " ") + "\"`" diff --git a/xorm/reverse.go b/xorm/reverse.go index 1c40f9e1..2a340ee3 100644 --- a/xorm/reverse.go +++ b/xorm/reverse.go @@ -1,11 +1,12 @@ package main import ( - "fmt" - //"github.com/lunny/xorm" "bytes" + "fmt" _ "github.com/bylevel/pq" + "github.com/dvirsky/go-pylog/logging" _ "github.com/go-sql-driver/mysql" + "github.com/lunny/xorm" _ "github.com/mattn/go-sqlite3" _ "github.com/ziutek/mymysql/godrv" "go/format" @@ -14,37 +15,56 @@ import ( "path" "path/filepath" "text/template" - "xorm" ) var CmdReverse = &Command{ - UsageLine: "reverse -m driverName datasourceName tmplpath", + UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]", Short: "reverse a db to codes", Long: ` according database's tables and columns to generate codes for Go, C++ and etc. + + -m Generated one go file for every table + driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres + datasourceName Database connection uri, for detail infomation please visit driver's project page + tmplPath Template dir for generated. the default templates dir has provide 1 template + generatedPath This parameter is optional, if blank, the default value is model, then will + generated all codes in model dir `, } func init() { CmdReverse.Run = runReverse - CmdReverse.Flags = map[string]bool{} + CmdReverse.Flags = map[string]bool{ + "-m": false, + } } func printReversePrompt(flag string) { } type Tmpl struct { - Table *xorm.Table + Tables []*xorm.Table Imports map[string]string Model string } func runReverse(cmd *Command, args []string) { + num := checkFlags(cmd.Flags, args, printReversePrompt) + if num == -1 { + return + } + args = args[num:] + if len(args) < 3 { fmt.Println("no") return } + var isMultiFile bool + if _, ok := cmd.Flags["-m"]; ok { + isMultiFile = true + } + curPath, err := os.Getwd() if err != nil { fmt.Println(curPath) @@ -69,23 +89,22 @@ func runReverse(cmd *Command, args []string) { Orm, err := xorm.NewEngine(args[0], args[1]) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return } tables, err := Orm.DBMetas() if err != nil { - fmt.Println(err) + logging.Error("%v", err) return } dir, err := filepath.Abs(args[2]) if err != nil { - fmt.Println(curPath) + logging.Error("%v", err) return } - var isMultiFile bool = true m := &xorm.SnakeMapper{} filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { @@ -95,7 +114,7 @@ func runReverse(cmd *Command, args []string) { bs, err := ioutil.ReadFile(f) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return err } @@ -107,7 +126,7 @@ func runReverse(cmd *Command, args []string) { tmpl, err := t.Parse(string(bs)) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return err } @@ -117,58 +136,85 @@ func runReverse(cmd *Command, args []string) { ext := path.Ext(newFileName) if !isMultiFile { - w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0700) + w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return err } - } - for _, table := range tables { - // imports imports := make(map[string]string) - for _, col := range table.Columns { - if typestring(col.SQLType) == "time.Time" { - imports["time.Time"] = "time.Time" - } - } - - if isMultiFile { - w, err = os.OpenFile(path.Join(genDir, m.Table2Obj(table.Name)+ext), os.O_RDWR|os.O_CREATE, 0700) - if err != nil { - fmt.Println(err) - return err + tbls := make([]*xorm.Table, 0) + for _, table := range tables { + for _, col := range table.Columns { + if typestring(col.SQLType) == "time.Time" { + imports["time"] = "time" + } } + tbls = append(tbls, table) } newbytes := bytes.NewBufferString("") - t := &Tmpl{Table: table, Imports: imports, Model: model} + t := &Tmpl{Tables: tbls, Imports: imports, Model: model} err = tmpl.Execute(newbytes, t) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return err } tplcontent, err := ioutil.ReadAll(newbytes) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return err } source, err := format.Source(tplcontent) if err != nil { - fmt.Println(err) + logging.Error("%v", err) return err } w.WriteString(string(source)) - if isMultiFile { + w.Close() + } else { + for _, table := range tables { + // imports + imports := make(map[string]string) + for _, col := range table.Columns { + if typestring(col.SQLType) == "time.Time" { + imports["time"] = "time" + } + } + + w, err := os.OpenFile(path.Join(genDir, unTitle(m.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + logging.Error("%v", err) + return err + } + + newbytes := bytes.NewBufferString("") + + t := &Tmpl{Tables: []*xorm.Table{table}, Imports: imports, Model: model} + err = tmpl.Execute(newbytes, t) + if err != nil { + logging.Error("%v", err) + return err + } + + tplcontent, err := ioutil.ReadAll(newbytes) + if err != nil { + logging.Error("%v", err) + return err + } + source, err := format.Source(tplcontent) + if err != nil { + logging.Error("%v", err) + return err + } + + w.WriteString(string(source)) w.Close() } } - if !isMultiFile { - w.Close() - } return nil }) diff --git a/xorm/templates/go/struct.go.tpl b/xorm/templates/go/struct.go.tpl index bb195d16..88955b3e 100644 --- a/xorm/templates/go/struct.go.tpl +++ b/xorm/templates/go/struct.go.tpl @@ -1,11 +1,14 @@ package {{.Model}} import ( - "github.com/lunny/xorm" {{range .Imports}}"{{.}}"{{end}} ) -type {{Mapper .Table.Name}} struct { -{{range .Table.Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag .}} +{{range .Tables}} +type {{Mapper .Name}} struct { +{{$table := .}} +{{range .Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag $table .}} {{end}} -} \ No newline at end of file +} + +{{end}} \ No newline at end of file diff --git a/xorm/xorm.go b/xorm/xorm.go index ac9789cc..2b8c486e 100644 --- a/xorm/xorm.go +++ b/xorm/xorm.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/dvirsky/go-pylog/logging" "io" "os" "runtime" @@ -28,6 +29,7 @@ func init() { } func main() { + logging.SetLevel(logging.ALL) // Check length of arguments. args := os.Args[1:] if len(args) < 1 {