From 2caed88b82eec4b1a78348415f0f784897cb871a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 12 Oct 2013 23:16:51 +0800 Subject: [PATCH] added xorm reverse tool --- base_test.go | 67 +++++++++++- engine.go | 129 +++++++++++++++-------- error.go | 13 +-- examples/sync.go | 25 ++++- mapper.go | 4 +- mymysql.go | 3 +- mysql.go | 165 +++++++++++++++++++++++++++++- postgres.go | 45 +++++++- session.go | 124 +++++++++++++++++++--- sqlite3.go | 23 ++++- statement.go | 30 +++--- table.go | 69 ++++++++----- xorm.go | 12 +-- xorm/c++.go | 1 + xorm/cmd.go | 50 +++++++++ xorm/go.go | 43 ++++++++ xorm/install.sh | 20 ++++ xorm/reverse.go | 176 ++++++++++++++++++++++++++++++++ xorm/templates/go/struct.go.tpl | 11 ++ xorm/xorm.go | 158 ++++++++++++++++++++++++++++ 20 files changed, 1052 insertions(+), 116 deletions(-) create mode 100644 xorm/c++.go create mode 100644 xorm/cmd.go create mode 100644 xorm/go.go create mode 100755 xorm/install.sh create mode 100644 xorm/reverse.go create mode 100644 xorm/templates/go/struct.go.tpl create mode 100644 xorm/xorm.go diff --git a/base_test.go b/base_test.go index 9ea2c00b..1f4a6ecc 100644 --- a/base_test.go +++ b/base_test.go @@ -3,6 +3,7 @@ package xorm import ( "errors" "fmt" + "strings" "testing" "time" ) @@ -98,7 +99,7 @@ func insert(engine *Engine, t *testing.T) { } } -func query(engine *Engine, t *testing.T) { +func testQuery(engine *Engine, t *testing.T) { sql := "select * from userinfo" results, err := engine.Query(sql) if err != nil { @@ -163,6 +164,19 @@ func insertMulti(engine *Engine, t *testing.T) { t.Error(err) panic(err) } + + users2 := []*Userinfo{ + &Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + &Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()}, + &Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()}, + } + + _, err = engine.Insert(&users2) + if err != nil { + t.Error(err) + panic(err) + } } func insertTwoTable(engine *Engine, t *testing.T) { @@ -1018,6 +1032,18 @@ func testIndexAndUnique(engine *Engine, t *testing.T) { t.Error(err) //panic(err) } + + err = engine.CreateIndexes(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } + + err = engine.CreateUniques(&IndexOrUnique{}) + if err != nil { + t.Error(err) + //panic(err) + } } type IntId struct { @@ -1042,6 +1068,12 @@ func testIntId(engine *Engine, t *testing.T) { t.Error(err) panic(err) } + + _, err = engine.Insert(&IntId{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } } func testInt32Id(engine *Engine, t *testing.T) { @@ -1056,6 +1088,31 @@ func testInt32Id(engine *Engine, t *testing.T) { t.Error(err) panic(err) } + + _, err = engine.Insert(&Int32Id{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } +} + +func testMetaInfo(engine *Engine, t *testing.T) { + tables, err := engine.DBMetas() + if err != nil { + t.Error(err) + panic(err) + } + + for _, table := range tables { + fmt.Println(table.Name) + for _, col := range table.Columns { + fmt.Println(col.String(engine.dialect)) + } + + for _, index := range table.Indexes { + fmt.Println(index.Name, index.Type, strings.Join(index.Cols, ",")) + } + } } func testAll(engine *Engine, t *testing.T) { @@ -1066,7 +1123,7 @@ func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- insert --------------") insert(engine, t) fmt.Println("-------------- query --------------") - query(engine, t) + testQuery(engine, t) fmt.Println("-------------- exec --------------") exec(engine, t) fmt.Println("-------------- insertAutoIncr --------------") @@ -1132,6 +1189,12 @@ func testAll2(engine *Engine, t *testing.T) { testCreatedAndUpdated(engine, t) fmt.Println("-------------- testIndexAndUnique --------------") testIndexAndUnique(engine, t) + fmt.Println("-------------- testIntId --------------") + //testIntId(engine, t) + fmt.Println("-------------- testInt32Id --------------") + //testInt32Id(engine, t) + fmt.Println("-------------- testMetaInfo --------------") + testMetaInfo(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/engine.go b/engine.go index ae21b44c..4298dbf9 100644 --- a/engine.go +++ b/engine.go @@ -20,7 +20,7 @@ const ( // a dialect is a driver's wrapper type dialect interface { - Init(uri string) error + Init(DriverName, DataSourceName string) error SqlType(t *Column) string SupportInsertMany() bool QuoteStr() string @@ -31,6 +31,10 @@ type dialect interface { IndexCheckSql(tableName, idxName string) (string, []interface{}) TableCheckSql(tableName string) (string, []interface{}) ColumnCheckSql(tableName, colName string) (string, []interface{}) + + GetColumns(tableName string) (map[string]*Column, error) + GetTables() ([]*Table, error) + GetIndexes(tableName string) (map[string]*Index, error) } type Engine struct { @@ -38,7 +42,7 @@ type Engine struct { TagIdentifier string DriverName string DataSourceName string - Dialect dialect + dialect dialect Tables map[reflect.Type]*Table mutex *sync.Mutex ShowSQL bool @@ -57,28 +61,28 @@ type Engine struct { // When the return is ture, then engine.Insert(&users) will // generate batch sql and exeute. func (engine *Engine) SupportInsertMany() bool { - return engine.Dialect.SupportInsertMany() + return engine.dialect.SupportInsertMany() } // Engine's database use which charactor as quote. // mysql, sqlite use ` and postgres use " func (engine *Engine) QuoteStr() string { - return engine.Dialect.QuoteStr() + return engine.dialect.QuoteStr() } // Use QuoteStr quote the string sql func (engine *Engine) Quote(sql string) string { - return engine.Dialect.QuoteStr() + sql + engine.Dialect.QuoteStr() + return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() } // A simple wrapper to dialect's SqlType method func (engine *Engine) SqlType(c *Column) string { - return engine.Dialect.SqlType(c) + return engine.dialect.SqlType(c) } // Database's autoincrement statement func (engine *Engine) AutoIncrStr() string { - return engine.Dialect.AutoIncrStr() + return engine.dialect.AutoIncrStr() } // Set engine's pool, the pool default is Go's standard library's connection pool. @@ -178,6 +182,28 @@ func (engine *Engine) NoAutoTime() *Session { return session.NoAutoTime() } +func (engine *Engine) DBMetas() ([]*Table, error) { + tables, err := engine.dialect.GetTables() + if err != nil { + return nil, err + } + + for _, table := range tables { + cols, err := engine.dialect.GetColumns(table.Name) + if err != nil { + return nil, err + } + table.Columns = cols + + indexes, err := engine.dialect.GetIndexes(table.Name) + if err != nil { + return nil, err + } + table.Indexes = indexes + } + return tables, nil +} + func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() session.IsAutoClose = true @@ -316,7 +342,7 @@ func (engine *Engine) MapType(t reflect.Type) *Table { if ormTagStr != "" { col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, - IsAutoIncrement: false, MapType: TWOSIDES} + IsAutoIncrement: false, MapType: TWOSIDES, Indexes: make(map[string]bool)} tags := strings.Split(ormTagStr, " ") if len(tags) > 0 { @@ -335,6 +361,8 @@ func (engine *Engine) MapType(t reflect.Type) *Table { table.PrimaryKey = parentTable.PrimaryKey continue } + var indexType int + var indexName string for j, key := range tags { k := strings.ToUpper(key) switch { @@ -358,37 +386,15 @@ func (engine *Engine) MapType(t reflect.Type) *Table { /*case strings.HasPrefix(k, "--"): col.Comment = k[2:len(k)]*/ case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): - indexName := k[len("INDEX")+1 : len(k)-1] - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col) - col.Index = index - } else { - index := NewIndex(indexName, false) - index.AddColumn(col) - table.AddIndex(index) - col.Index = index - } + indexType = IndexType + indexName = k[len("INDEX")+1 : len(k)-1] case k == "INDEX": - index := NewIndex(col.Name, false) - index.AddColumn(col) - table.AddIndex(index) - col.Index = index + indexType = IndexType case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): - indexName := k[len("UNIQUE")+1 : len(k)-1] - if index, ok := table.Indexes[indexName]; ok { - index.AddColumn(col) - col.Index = index - } else { - index := NewIndex(indexName, true) - index.AddColumn(col) - table.AddIndex(index) - col.Index = index - } + indexName = k[len("UNIQUE")+1 : len(k)-1] + indexType = UniqueType case k == "UNIQUE": - index := NewIndex(col.Name, true) - index.AddColumn(col) - table.AddIndex(index) - col.Index = index + indexType = UniqueType case k == "NOTNULL": col.Nullable = false case k == "NOT": @@ -432,12 +438,39 @@ func (engine *Engine) MapType(t reflect.Type) *Table { if col.Name == "" { col.Name = engine.Mapper.Obj2Table(t.Field(i).Name) } + if indexType == IndexType { + if indexName == "" { + indexName = col.Name + } + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = true + } else { + index := NewIndex(indexName, IndexType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = true + } + } else if indexType == UniqueType { + if indexName == "" { + indexName = col.Name + } + if index, ok := table.Indexes[indexName]; ok { + index.AddColumn(col.Name) + col.Indexes[index.Name] = true + } else { + index := NewIndex(indexName, UniqueType) + index.AddColumn(col.Name) + table.AddIndex(index) + col.Indexes[index.Name] = true + } + } } } else { sqlType := Type2SQLType(fieldType) col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, - sqlType.DefaultLength, sqlType.DefaultLength2, true, "", nil, false, false, - TWOSIDES, false, false, ""} + sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, + TWOSIDES, false, false, false} } if col.IsAutoIncrement { col.Nullable = false @@ -498,6 +531,20 @@ func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { return has, err } +// create indexes +func (engine *Engine) CreateIndexes(bean interface{}) error { + session := engine.NewSession() + defer session.Close() + return session.CreateIndexes(bean) +} + +// create uniques +func (engine *Engine) CreateUniques(bean interface{}) error { + session := engine.NewSession() + defer session.Close() + return session.CreateUniques(bean) +} + // If enabled cache, clear the cache bean func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { t := rType(bean) @@ -585,7 +632,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { session := engine.NewSession() session.Statement.RefTable = table defer session.Close() - if index.IsUnique { + if index.Type == UniqueType { isExist, err := session.isIndexExist(table.Name, name, true) if err != nil { return err @@ -599,7 +646,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } } - } else { + } else if index.Type == IndexType { isExist, err := session.isIndexExist(table.Name, name, false) if err != nil { return err @@ -613,6 +660,8 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } } + } else { + return errors.New("unknow index type") } } } diff --git a/error.go b/error.go index d582110f..c868173f 100644 --- a/error.go +++ b/error.go @@ -5,10 +5,11 @@ import ( ) var ( - ErrParamsType error = errors.New("params type error") - ErrTableNotFound error = errors.New("not found table") - ErrUnSupportedType error = errors.New("unsupported type error") - ErrNotExist error = errors.New("not exist error") - ErrCacheFailed error = errors.New("cache failed") - ErrNeedDeletedCond error = errors.New("delete need at least one condition") + ErrParamsType error = errors.New("Params type error") + ErrTableNotFound error = errors.New("Not found table") + ErrUnSupportedType error = errors.New("Unsupported type error") + ErrNotExist error = errors.New("Not exist error") + ErrCacheFailed error = errors.New("Cache failed") + ErrNeedDeletedCond error = errors.New("Delete need at least one condition") + ErrNotImplemented error = errors.New("Not implemented.") ) diff --git a/examples/sync.go b/examples/sync.go index 2c93a5a2..084d32f6 100644 --- a/examples/sync.go +++ b/examples/sync.go @@ -9,9 +9,14 @@ import ( ) type SyncUser struct { - Id int64 - Name string `xorm:"unique"` - Age int `xorm:"index"` + Id int64 + Name string `xorm:"unique"` + Age int `xorm:"index"` + Title string + Address string + Genre string + Area string + Date int } type SyncLoginInfo struct { @@ -61,5 +66,19 @@ func main() { if err != nil { fmt.Println(err) } + + user := &SyncUser{ + Name: "testsdf", + Age: 15, + Title: "newsfds", + Address: "fasfdsafdsaf", + Genre: "fsafd", + Area: "fafdsafd", + Date: 1000, + } + _, err = Orm.Insert(user) + if err != nil { + fmt.Println(err) + } } } diff --git a/mapper.go b/mapper.go index 224658bb..14b6ae0e 100644 --- a/mapper.go +++ b/mapper.go @@ -75,7 +75,9 @@ func titleCasedName(name string) string { switch { case upNextChar: upNextChar = false - chr -= ('a' - 'A') + if 'a' <= chr && chr <= 'z' { + chr -= ('a' - 'A') + } case chr == '_': upNextChar = true continue diff --git a/mymysql.go b/mymysql.go index d9317c2f..2413af1a 100644 --- a/mymysql.go +++ b/mymysql.go @@ -17,7 +17,8 @@ type mymysql struct { passwd string } -func (db *mymysql) Init(uri string) error { +func (db *mymysql) Init(drivername, uri string) error { + db.mysql.base.init(drivername, uri) pd := strings.SplitN(uri, "*", 2) if len(pd) == 2 { // Parse protocol part of URI diff --git a/mysql.go b/mysql.go index 080920cc..8f6effc9 100644 --- a/mysql.go +++ b/mysql.go @@ -2,14 +2,26 @@ package xorm import ( "crypto/tls" - //"fmt" + "database/sql" + "errors" + "fmt" "regexp" "strconv" - //"strings" + "strings" "time" ) +type base struct { + drivername string + dataSourceName string +} + +func (b *base) init(drivername, dataSourceName string) { + b.drivername, b.dataSourceName = drivername, dataSourceName +} + type mysql struct { + base user string passwd string net string @@ -56,7 +68,8 @@ func (cfg *mysql) parseDSN(dsn string) (err error) { return } -func (db *mysql) Init(uri string) error { +func (db *mysql) Init(drivername, uri string) error { + db.base.init(drivername, uri) return db.parseDSN(uri) } @@ -133,3 +146,149 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" return sql, args } + +func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { + args := []interface{}{db.dbname, tableName} + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + cols := make(map[string]*Column) + for _, record := range res { + col := new(Column) + for name, content := range record { + switch name { + case "COLUMN_NAME": + col.Name = string(content) + case "IS_NULLABLE": + if "YES" == string(content) { + col.Nullable = true + } + case "COLUMN_DEFAULT": + // add '' + col.Default = string(content) + case "COLUMN_TYPE": + cts := strings.Split(string(content), "(") + var len1, len2 int + if len(cts) == 2 { + lens := strings.Split(cts[1][0:len(cts[1])-1], ",") + len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) + if err != nil { + return nil, err + } + if len(lens) == 2 { + len2, err = strconv.Atoi(lens[1]) + if err != nil { + return nil, err + } + } + } + colName := cts[0] + colType := strings.ToUpper(colName) + col.Length = len1 + col.Length2 = len2 + if _, ok := sqlTypes[colType]; ok { + col.SQLType = SQLType{colType, len1, len2} + } else { + return nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) + } + case "COLUMN_KEY": + key := string(content) + if key == "PRI" { + col.IsPrimaryKey = true + } + if key == "UNI" { + //col.is + } + case "EXTRA": + extra := string(content) + if extra == "auto_increment" { + col.IsAutoIncrement = true + } + } + } + cols[col.Name] = col + } + return cols, nil +} + +func (db *mysql) GetTables() ([]*Table, error) { + args := []interface{}{db.dbname} + s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "TABLE_NAME": + table.Name = string(content) + case "ENGINE": + } + } + tables = append(tables, table) + } + return tables, nil +} + +func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { + args := []interface{}{db.dbname, tableName} + s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName, colName string + for name, content := range record { + switch name { + case "NON_UNIQUE": + if "YES" == string(content) { + indexType = IndexType + } else { + indexType = UniqueType + } + case "INDEX_NAME": + indexName = string(content) + case "COLUMN_NAME": + colName = string(content) + } + } + if indexName == "PRIMARY" { + continue + } + indexName = indexName[5+len(tableName) : len(indexName)] + + var index *Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(Index) + index.Type = indexType + index.Name = indexName + indexes[indexName] = index + } + index.AddColumn(colName) + } + return indexes, nil +} diff --git a/postgres.go b/postgres.go index 96a917e7..ce612810 100644 --- a/postgres.go +++ b/postgres.go @@ -8,6 +8,7 @@ import ( ) type postgres struct { + base dbname string } @@ -44,7 +45,9 @@ func parseOpts(name string, o values) { } } -func (db *postgres) Init(uri string) error { +func (db *postgres) Init(drivername, uri string) error { + db.base.init(drivername, uri) + o := make(values) parseOpts(uri, o) @@ -135,3 +138,43 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + " AND column_name = ?", args } + +func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { + args := []interface{}{tableName} + s := "SELECT COLUMN_NAME, column_default, is_nullable, data_type, character_maximum_length" + + " FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + + cnn, err := sql.Open(db.drivername, db.dataSourceName) + if err != nil { + return nil, err + } + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } + cols := make(map[string]*Column) + for _, record := range res { + col := new(Column) + + for name, content := range record { + switch name { + case "COLUMN_NAME": + col.Name = string(content) + case "column_default": + if strings.HasPrefix(string(content), "") { + col.IsPrimaryKey + } + } + } + } + + return nil, ErrNotImplemented +} + +func (db *postgres) GetTables() ([]*Table, error) { + return nil, ErrNotImplemented +} + +func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { + return nil, ErrNotImplemented +} diff --git a/session.go b/session.go index bba1f195..281b65ae 100644 --- a/session.go +++ b/session.go @@ -11,6 +11,8 @@ import ( "time" ) +// Struct Session keep a pointer to sql.DB and provides all execution of all +// kind of database operations. type Session struct { Db *sql.DB Engine *Engine @@ -22,6 +24,7 @@ type Session struct { IsAutoClose bool } +// Method Init reset the session as the init status. func (session *Session) Init() { session.Statement = Statement{Engine: session.Engine} session.Statement.Init() @@ -30,6 +33,7 @@ func (session *Session) Init() { session.IsAutoClose = false } +// Method Close release the connection from pool func (session *Session) Close() { defer func() { if session.Db != nil { @@ -41,56 +45,64 @@ func (session *Session) Close() { }() } +// Method Sql provides raw sql input parameter. When you have a complex SQL statement +// and cannot use Where, Id, In and etc. Methods to describe, you can use Sql. func (session *Session) Sql(querystring string, args ...interface{}) *Session { session.Statement.Sql(querystring, args...) return session } +// Method Where provides custom query condition. func (session *Session) Where(querystring string, args ...interface{}) *Session { session.Statement.Where(querystring, args...) return session } +// Method Id provides converting id as a query condition func (session *Session) Id(id int64) *Session { session.Statement.Id(id) return session } +// Method Table can input a string or pointer to struct for special a table to operate. func (session *Session) Table(tableNameOrBean interface{}) *Session { session.Statement.Table(tableNameOrBean) return session } +// Method In provides a query string like "id in (1, 2, 3)" func (session *Session) In(column string, args ...interface{}) *Session { session.Statement.In(column, args...) return session } +// Method Cols provides some columns to special func (session *Session) Cols(columns ...string) *Session { session.Statement.Cols(columns...) return session } +// Method NoAutoTime means do not automatically give created field and updated field +// the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { session.Statement.UseAutoTime = false return session } -/*func (session *Session) Trans(t string) *Session { - session.TransType = t - return session -}*/ - +// Method Limit provide limit and offset query condition func (session *Session) Limit(limit int, start ...int) *Session { session.Statement.Limit(limit, start...) return session } +// Method OrderBy provide order by query condition, the input parameter is the content +// after order by on a sql statement. func (session *Session) OrderBy(order string) *Session { session.Statement.OrderBy(order) return session } +// Method Desc provide desc order by query condition, the input parameters are columns. func (session *Session) Desc(colNames ...string) *Session { if session.Statement.OrderStr != "" { session.Statement.OrderStr += ", " @@ -101,6 +113,7 @@ func (session *Session) Desc(colNames ...string) *Session { return session } +// Method Asc provide asc order by query condition, the input parameters are columns. func (session *Session) Asc(colNames ...string) *Session { if session.Statement.OrderStr != "" { session.Statement.OrderStr += ", " @@ -111,16 +124,19 @@ func (session *Session) Asc(colNames ...string) *Session { return session } +// Method StoreEngine is only avialble mysql dialect currently func (session *Session) StoreEngine(storeEngine string) *Session { session.Statement.StoreEngine = storeEngine return session } +// Method StoreEngine is only avialble charset dialect currently func (session *Session) Charset(charset string) *Session { session.Statement.Charset = charset return session } +// Method Cascade func (session *Session) Cascade(trueOrFalse ...bool) *Session { if len(trueOrFalse) >= 1 { session.Statement.UseCascade = trueOrFalse[0] @@ -128,6 +144,8 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session { return session } +// Method NoCache ask this session do not retrieve data from cache system and +// get data from database directly. func (session *Session) NoCache() *Session { session.Statement.UseCache = false return session @@ -836,7 +854,7 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.Dialect.ColumnCheckSql(tableName, colName) + sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName) results, err := session.query(sql, args...) return len(results) > 0, err } @@ -850,7 +868,7 @@ func (session *Session) isTableExist(tableName string) (bool, error) { if session.IsAutoClose { defer session.Close() } - sql, args := session.Engine.Dialect.TableCheckSql(tableName) + sql, args := session.Engine.dialect.TableCheckSql(tableName) results, err := session.query(sql, args...) return len(results) > 0, err } @@ -870,7 +888,7 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo } else { idx = indexName(tableName, idxName) } - sql, args := session.Engine.Dialect.IndexCheckSql(tableName, idx) + sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx) results, err := session.query(sql, args...) return len(results) > 0, err } @@ -901,7 +919,7 @@ func (session *Session) addIndex(tableName, idxName string) error { defer session.Close() } //fmt.Println(idxName) - cols := session.Statement.RefTable.Indexes[idxName].GenColsStr() + cols := session.Statement.RefTable.Indexes[idxName].Cols sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) _, err = session.exec(sql, args...) return err @@ -917,7 +935,7 @@ func (session *Session) addUnique(tableName, uqeName string) error { defer session.Close() } //fmt.Println(uqeName, session.Statement.RefTable.Uniques) - cols := session.Statement.RefTable.Indexes[uqeName].GenColsStr() + cols := session.Statement.RefTable.Indexes[uqeName].Cols sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) _, err = session.exec(sql, args...) return err @@ -945,6 +963,79 @@ func (session *Session) DropAll() error { return nil } +func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { + s, err := db.Prepare(sql) + if err != nil { + return nil, err + } + defer s.Close() + res, err := s.Query(params...) + if err != nil { + return nil, err + } + defer res.Close() + fields, err := res.Columns() + if err != nil { + return nil, err + } + for res.Next() { + result := make(map[string][]byte) + var scanResultContainers []interface{} + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := res.Scan(scanResultContainers...); err != nil { + return nil, err + } + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + //fmt.Println("ignore ...", key, rawValue) + continue + } + aa := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + var str string + switch aa.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + result[key] = []byte(str) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + result[key] = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + result[key] = []byte(str) + case reflect.Slice: + switch aa.Elem().Kind() { + case reflect.Uint8: + result[key] = rawValue.Interface().([]byte) + default: + //session.Engine.LogError("Unsupported type") + } + case reflect.String: + str = vv.String() + result[key] = []byte(str) + //时间类型 + case reflect.Struct: + if aa.String() == "time.Time" { + str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") + result[key] = []byte(str) + } else { + //session.Engine.LogError("Unsupported struct type") + } + default: + //session.Engine.LogError("Unsupported type") + } + } + resultsSlice = append(resultsSlice, result) + } + return resultsSlice, nil +} + func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { for _, filter := range session.Engine.Filters { sql = filter.Do(sql, session) @@ -953,7 +1044,9 @@ func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice session.Engine.LogSQL(sql) session.Engine.LogSQL(paramStr) - s, err := session.Db.Prepare(sql) + return query(session.Db, sql, paramStr...) + + /*s, err := session.Db.Prepare(sql) if err != nil { return nil, err } @@ -1022,7 +1115,7 @@ func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice } resultsSlice = append(resultsSlice, result) } - return resultsSlice, nil + return resultsSlice, nil*/ } func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { @@ -1446,9 +1539,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { var v interface{} = id switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: v = uint(id) } pkValue.Set(reflect.ValueOf(v)) @@ -1456,6 +1549,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return id, nil } +// Method InsertOne insert only one struct into database as a record. +// The in parameter bean must a struct or a point to struct. The return +// parameter is lastInsertId and error func (session *Session) InsertOne(bean interface{}) (int64, error) { err := session.newDb() if err != nil { diff --git a/sqlite3.go b/sqlite3.go index 1d464524..b0b6924e 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1,9 +1,11 @@ package xorm type sqlite3 struct { + base } -func (db *sqlite3) Init(uri string) error { +func (db *sqlite3) Init(drivername, dataSourceName string) error { + db.base.init(drivername, dataSourceName) return nil } @@ -69,3 +71,22 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac args := []interface{}{tableName} return "SELECT name FROM sqlite_master WHERE type='table' and name = ? and sql like '%`" + colName + "`%'", args } + +func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { + /*args := []interface{}{db.dbname, tableName} + + SELECT sql FROM sqlite_master WHERE type='table' and name = 'category'; + sql := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + + return sql, args*/ + return nil, ErrNotImplemented +} + +func (db *sqlite3) GetTables() ([]*Table, error) { + return nil, ErrNotImplemented +} + +func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { + return nil, ErrNotImplemented +} diff --git a/statement.go b/statement.go index df36852b..1d7ef006 100644 --- a/statement.go +++ b/statement.go @@ -262,15 +262,15 @@ func (statement *Statement) genCreateSQL() string { sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " (" for _, colName := range statement.RefTable.ColumnsSeq { col := statement.RefTable.Columns[colName] - sql += col.String(statement.Engine) + sql += col.String(statement.Engine.dialect) sql = strings.TrimSpace(sql) sql += ", " } sql = sql[:len(sql)-2] + ")" - if statement.Engine.Dialect.SupportEngine() && statement.StoreEngine != "" { + if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" { sql += " ENGINE=" + statement.StoreEngine } - if statement.Engine.Dialect.SupportCharset() && statement.Charset != "" { + if statement.Engine.dialect.SupportCharset() && statement.Charset != "" { sql += " DEFAULT CHARSET " + statement.Charset } sql += ";" @@ -286,9 +286,11 @@ func (s *Statement) genIndexSQL() []string { tbName := s.TableName() quote := s.Engine.Quote for idxName, index := range s.RefTable.Indexes { - sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), - quote(tbName), quote(strings.Join(index.GenColsStr(), quote(",")))) - sqls = append(sqls, sql) + if index.Type == IndexType { + sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), + quote(tbName), quote(strings.Join(index.Cols, quote(",")))) + sqls = append(sqls, sql) + } } return sqls } @@ -302,9 +304,11 @@ func (s *Statement) genUniqueSQL() []string { tbName := s.TableName() quote := s.Engine.Quote for idxName, unique := range s.RefTable.Indexes { - sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)), - quote(tbName), quote(strings.Join(unique.GenColsStr(), quote(",")))) - sqls = append(sqls, sql) + if unique.Type == UniqueType { + sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)), + quote(tbName), quote(strings.Join(unique.Cols, quote(",")))) + sqls = append(sqls, sql) + } } return sqls } @@ -313,13 +317,13 @@ func (s *Statement) genDelIndexSQL() []string { var sqls []string = make([]string, 0) for idxName, index := range s.RefTable.Indexes { var rIdxName string - if index.IsUnique { + if index.Type == UniqueType { rIdxName = uniqueName(s.TableName(), idxName) - } else { + } else if index.Type == IndexType { rIdxName = indexName(s.TableName(), idxName) } sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName)) - if s.Engine.Dialect.IndexOnTable() { + if s.Engine.dialect.IndexOnTable() { sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName())) } sqls = append(sqls, sql) @@ -351,7 +355,7 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { func (s *Statement) genAddColumnStr(col *Column) (string, []interface{}) { quote := s.Engine.Quote sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()), - col.String(s.Engine)) + col.String(s.Engine.dialect)) return sql, []interface{}{} } diff --git a/table.go b/table.go index 0f9b8df4..294cb438 100644 --- a/table.go +++ b/table.go @@ -143,35 +143,57 @@ func Type2SQLType(t reflect.Type) (st SQLType) { return } +func SQLType2Type(st SQLType) reflect.Type { + switch st.Name { + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: + return reflect.TypeOf(1) + case BigInt, BigSerial: + return reflect.TypeOf(int64(1)) + case Float, Real: + return reflect.TypeOf(float32(1)) + case Double: + return reflect.TypeOf(float64(1)) + case Char, Varchar, TinyText, Text, MediumText, LongText: + return reflect.TypeOf("") + case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary: + return reflect.TypeOf([]byte{}) + case Bool: + return reflect.TypeOf(true) + case DateTime, Date, Time, TimeStamp: + return reflect.TypeOf(tm) + case Decimal, Numeric: + return reflect.TypeOf("") + default: + return reflect.TypeOf("") + } +} + const ( - TWOSIDES = iota + 1 - ONLYTODB - ONLYFROMDB + IndexType = iota + 1 + UniqueType ) type Index struct { - Name string - IsUnique bool - Cols []*Column + Name string + Type int + Cols []string } -func (index *Index) AddColumn(cols ...*Column) { +func (index *Index) AddColumn(cols ...string) { for _, col := range cols { index.Cols = append(index.Cols, col) } } -func (index *Index) GenColsStr() []string { - names := make([]string, len(index.Cols)) - for idx, col := range index.Cols { - names[idx] = col.Name - } - return names +func NewIndex(name string, indexType int) *Index { + return &Index{name, indexType, make([]string, 0)} } -func NewIndex(name string, isUnique bool) *Index { - return &Index{name, isUnique, make([]*Column, 0)} -} +const ( + TWOSIDES = iota + 1 + ONLYTODB + ONLYFROMDB +) type Column struct { Name string @@ -181,26 +203,26 @@ type Column struct { Length2 int Nullable bool Default string - Index *Index + Indexes map[string]bool IsPrimaryKey bool IsAutoIncrement bool MapType int IsCreated bool IsUpdated bool - Comment string + IsCascade bool } -func (col *Column) String(engine *Engine) string { - sql := engine.Quote(col.Name) + " " +func (col *Column) String(d dialect) string { + sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " - sql += engine.SqlType(col) + " " + sql += d.SqlType(col) + " " if col.IsPrimaryKey { sql += "PRIMARY KEY " } if col.IsAutoIncrement { - sql += engine.AutoIncrStr() + " " + sql += d.AutoIncrStr() + " " } if col.Nullable { @@ -213,9 +235,6 @@ func (col *Column) String(engine *Engine) string { sql += "DEFAULT " + col.Default + " " } - if col.Comment != "" { - sql += "COMMENT '" + col.Comment + "' " - } return sql } diff --git a/xorm.go b/xorm.go index d2dc91eb..9a6c30b5 100644 --- a/xorm.go +++ b/xorm.go @@ -10,7 +10,7 @@ import ( ) const ( - version string = "0.1.9" + version string = "0.2.0" ) func close(engine *Engine) { @@ -24,19 +24,19 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { DataSourceName: dataSourceName, Filters: make([]Filter, 0)} if driverName == SQLITE { - engine.Dialect = &sqlite3{} + engine.dialect = &sqlite3{} } else if driverName == MYSQL { - engine.Dialect = &mysql{} + engine.dialect = &mysql{} } else if driverName == POSTGRES { - engine.Dialect = &postgres{} + engine.dialect = &postgres{} engine.Filters = append(engine.Filters, &PgSeqFilter{}) engine.Filters = append(engine.Filters, &QuoteFilter{}) } else if driverName == MYMYSQL { - engine.Dialect = &mymysql{} + engine.dialect = &mymysql{} } else { return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) } - err := engine.Dialect.Init(dataSourceName) + err := engine.dialect.Init(driverName, dataSourceName) if err != nil { return nil, err } diff --git a/xorm/c++.go b/xorm/c++.go new file mode 100644 index 00000000..06ab7d0f --- /dev/null +++ b/xorm/c++.go @@ -0,0 +1 @@ +package main diff --git a/xorm/cmd.go b/xorm/cmd.go new file mode 100644 index 00000000..527a1836 --- /dev/null +++ b/xorm/cmd.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + "os" + "strings" +) + +// A Command is an implementation of a go command +// like go build or go fix. +type Command struct { + // Run runs the command. + // The args are the arguments after the command name. + Run func(cmd *Command, args []string) + + // UsageLine is the one-line usage message. + // The first word in the line is taken to be the command name. + UsageLine string + + // Short is the short description shown in the 'go help' output. + Short string + + // Long is the long message shown in the 'go help ' output. + Long string + + // Flag is a set of flags specific to this command. + Flags map[string]bool +} + +// Name returns the command's name: the first word in the usage line. +func (c *Command) Name() string { + name := c.UsageLine + i := strings.Index(name, " ") + if i >= 0 { + name = name[:i] + } + return name +} + +func (c *Command) Usage() { + fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine) + fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long)) + os.Exit(2) +} + +// Runnable reports whether the command can be run; otherwise +// it is a documentation pseudo-command such as importpath. +func (c *Command) Runnable() bool { + return c.Run != nil +} diff --git a/xorm/go.go b/xorm/go.go new file mode 100644 index 00000000..19bbd09c --- /dev/null +++ b/xorm/go.go @@ -0,0 +1,43 @@ +package main + +import ( + //"github.com/lunny/xorm" + "strings" + "xorm" +) + +func typestring(st xorm.SQLType) string { + t := xorm.SQLType2Type(st) + s := t.String() + if s == "[]uint8" { + return "[]byte" + } + return s +} + +func tag(col *xorm.Column) string { + res := make([]string, 0) + if !col.Nullable { + res = append(res, "not null") + } + if col.IsPrimaryKey { + res = append(res, "pk") + } + if col.Default != "" { + res = append(res, "default "+col.Default) + } + if col.IsAutoIncrement { + res = append(res, "autoincr") + } + if col.IsCreated { + res = append(res, "created") + } + if col.IsUpdated { + res = append(res, "updated") + } + + if len(res) > 0 { + return "`xorm:\"" + strings.Join(res, " ") + "\"`" + } + return "" +} diff --git a/xorm/install.sh b/xorm/install.sh new file mode 100755 index 00000000..e8455d2a --- /dev/null +++ b/xorm/install.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +if [ ! -f install.sh ]; then +echo 'install must be run within its container folder' 1>&2 +exit 1 +fi + +CURDIR=`pwd` +NEWPATH="$GOPATH/src/github.com/lunny/xorm/${PWD##*/}" +if [ ! -d "$NEWPATH" ]; then +ln -s $CURDIR $NEWPATH +fi + +gofmt -w $CURDIR + +cd $NEWPATH +go install ${PWD##*/} +cd $CURDIR + +echo 'finished' diff --git a/xorm/reverse.go b/xorm/reverse.go new file mode 100644 index 00000000..1c40f9e1 --- /dev/null +++ b/xorm/reverse.go @@ -0,0 +1,176 @@ +package main + +import ( + "fmt" + //"github.com/lunny/xorm" + "bytes" + _ "github.com/bylevel/pq" + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" + _ "github.com/ziutek/mymysql/godrv" + "go/format" + "io/ioutil" + "os" + "path" + "path/filepath" + "text/template" + "xorm" +) + +var CmdReverse = &Command{ + UsageLine: "reverse -m driverName datasourceName tmplpath", + Short: "reverse a db to codes", + Long: ` +according database's tables and columns to generate codes for Go, C++ and etc. +`, +} + +func init() { + CmdReverse.Run = runReverse + CmdReverse.Flags = map[string]bool{} +} + +func printReversePrompt(flag string) { +} + +type Tmpl struct { + Table *xorm.Table + Imports map[string]string + Model string +} + +func runReverse(cmd *Command, args []string) { + if len(args) < 3 { + fmt.Println("no") + return + } + + curPath, err := os.Getwd() + if err != nil { + fmt.Println(curPath) + return + } + + var genDir string + var model string + if len(args) == 4 { + genDir, err = filepath.Abs(args[3]) + if err != nil { + fmt.Println(err) + return + } + model = path.Base(genDir) + } else { + model = "model" + genDir = path.Join(curPath, model) + } + + os.MkdirAll(genDir, os.ModePerm) + + Orm, err := xorm.NewEngine(args[0], args[1]) + if err != nil { + fmt.Println(err) + return + } + + tables, err := Orm.DBMetas() + if err != nil { + fmt.Println(err) + return + } + + dir, err := filepath.Abs(args[2]) + if err != nil { + fmt.Println(curPath) + return + } + + var isMultiFile bool = true + m := &xorm.SnakeMapper{} + + filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { + if info.IsDir() { + return nil + } + + bs, err := ioutil.ReadFile(f) + if err != nil { + fmt.Println(err) + return err + } + + t := template.New(f) + t.Funcs(template.FuncMap{"Mapper": m.Table2Obj, + "Type": typestring, + "Tag": tag, + }) + + tmpl, err := t.Parse(string(bs)) + if err != nil { + fmt.Println(err) + return err + } + + var w *os.File + fileName := info.Name() + newFileName := fileName[:len(fileName)-4] + ext := path.Ext(newFileName) + + if !isMultiFile { + w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0700) + if err != nil { + fmt.Println(err) + return err + } + } + + for _, table := range tables { + // imports + imports := make(map[string]string) + for _, col := range table.Columns { + if typestring(col.SQLType) == "time.Time" { + imports["time.Time"] = "time.Time" + } + } + + if isMultiFile { + w, err = os.OpenFile(path.Join(genDir, m.Table2Obj(table.Name)+ext), os.O_RDWR|os.O_CREATE, 0700) + if err != nil { + fmt.Println(err) + return err + } + } + + newbytes := bytes.NewBufferString("") + + t := &Tmpl{Table: table, Imports: imports, Model: model} + err = tmpl.Execute(newbytes, t) + if err != nil { + fmt.Println(err) + return err + } + + tplcontent, err := ioutil.ReadAll(newbytes) + if err != nil { + fmt.Println(err) + return err + } + source, err := format.Source(tplcontent) + if err != nil { + fmt.Println(err) + return err + } + + w.WriteString(string(source)) + if isMultiFile { + w.Close() + } + } + if !isMultiFile { + w.Close() + } + + return nil + }) + +} diff --git a/xorm/templates/go/struct.go.tpl b/xorm/templates/go/struct.go.tpl new file mode 100644 index 00000000..bb195d16 --- /dev/null +++ b/xorm/templates/go/struct.go.tpl @@ -0,0 +1,11 @@ +package {{.Model}} + +import ( + "github.com/lunny/xorm" + {{range .Imports}}"{{.}}"{{end}} +) + +type {{Mapper .Table.Name}} struct { +{{range .Table.Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag .}} +{{end}} +} \ No newline at end of file diff --git a/xorm/xorm.go b/xorm/xorm.go new file mode 100644 index 00000000..ac9789cc --- /dev/null +++ b/xorm/xorm.go @@ -0,0 +1,158 @@ +package main + +import ( + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + "text/template" + "unicode" + "unicode/utf8" +) + +// +build go1.1 + +// Test that go1.1 tag above is included in builds. main.go refers to this definition. +const go11tag = true + +// Commands lists the available commands and help topics. +// The order here is the order in which they are printed by 'gopm help'. +var commands = []*Command{ + CmdReverse, +} + +func init() { + runtime.GOMAXPROCS(runtime.NumCPU()) +} + +func main() { + // Check length of arguments. + args := os.Args[1:] + if len(args) < 1 { + usage() + return + } + + // Show help documentation. + if args[0] == "help" { + help(args[1:]) + return + } + + // Check commands and run. + for _, comm := range commands { + if comm.Name() == args[0] && comm.Run != nil { + comm.Run(comm, args[1:]) + exit() + return + } + } + + fmt.Fprintf(os.Stderr, "xorm: unknown subcommand %q\nRun 'xorm help' for usage.\n", args[0]) + setExitStatus(2) + exit() +} + +var exitStatus = 0 +var exitMu sync.Mutex + +func setExitStatus(n int) { + exitMu.Lock() + if exitStatus < n { + exitStatus = n + } + exitMu.Unlock() +} + +var usageTemplate = `xorm is a database tool based xorm package. + +Usage: + + xorm command [arguments] + +The commands are: +{{range .}}{{if .Runnable}} + {{.Name | printf "%-11s"}} {{.Short}}{{end}}{{end}} + +Use "xorm help [command]" for more information about a command. + +Additional help topics: +{{range .}}{{if not .Runnable}} + {{.Name | printf "%-11s"}} {{.Short}}{{end}}{{end}} + +Use "xorm help [topic]" for more information about that topic. + +` + +var helpTemplate = `{{if .Runnable}}usage: go {{.UsageLine}} + +{{end}}{{.Long | trim}} +` + +// tmpl executes the given template text on data, writing the result to w. +func tmpl(w io.Writer, text string, data interface{}) { + t := template.New("top") + t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize}) + template.Must(t.Parse(text)) + if err := t.Execute(w, data); err != nil { + panic(err) + } +} + +func capitalize(s string) string { + if s == "" { + return s + } + r, n := utf8.DecodeRuneInString(s) + return string(unicode.ToTitle(r)) + s[n:] +} + +func printUsage(w io.Writer) { + tmpl(w, usageTemplate, commands) +} + +func usage() { + printUsage(os.Stderr) + os.Exit(2) +} + +// help implements the 'help' command. +func help(args []string) { + if len(args) == 0 { + printUsage(os.Stdout) + // not exit 2: succeeded at 'gopm help'. + return + } + if len(args) != 1 { + fmt.Fprintf(os.Stderr, "usage: xorm help command\n\nToo many arguments given.\n") + os.Exit(2) // failed at 'gopm help' + } + + arg := args[0] + + for _, cmd := range commands { + if cmd.Name() == arg { + tmpl(os.Stdout, helpTemplate, cmd) + // not exit 2: succeeded at 'gopm help cmd'. + return + } + } + + fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'xorm help'.\n", arg) + os.Exit(2) // failed at 'gopm help cmd' +} + +var atexitFuncs []func() + +func atexit(f func()) { + atexitFuncs = append(atexitFuncs, f) +} + +func exit() { + for _, f := range atexitFuncs { + f() + } + os.Exit(exitStatus) +}