From bab16dc76360ad4a01bba83dfbc7ff141dc23f59 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sun, 27 Oct 2013 09:10:20 +0800 Subject: [PATCH] added type and sequence for xorm tool;added max connect for pool(go1.2+) --- base_test.go | 30 ++++++ engine.go | 5 +- examples/maxconnect.go | 4 +- mysql.go | 18 ++-- pool.go | 5 + postgres.go | 16 +-- sqlite3.go | 12 ++- xorm/go.go | 152 ++++++++++++++++++++++++++++ xorm/templates/goxorm/struct.go.tpl | 6 +- xorm/xorm.go | 2 +- 10 files changed, 224 insertions(+), 26 deletions(-) diff --git a/base_test.go b/base_test.go index a58e0304..d5c02cd2 100644 --- a/base_test.go +++ b/base_test.go @@ -1128,6 +1128,34 @@ func testIterate(engine *Engine, t *testing.T) { } } +type StrangeName struct { + Id_t int64 `xorm:"pk autoincr"` + Name string +} + +func testStrangeName(engine *Engine, t *testing.T) { + err := engine.DropTables(new(StrangeName)) + if err != nil { + t.Error(err) + } + + err = engine.CreateTables(new(StrangeName)) + if err != nil { + t.Error(err) + } + + _, err = engine.Insert(&StrangeName{Name: "sfsfdsfds"}) + if err != nil { + t.Error(err) + } + + beans := make([]StrangeName, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + } +} + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -1210,6 +1238,8 @@ func testAll2(engine *Engine, t *testing.T) { testMetaInfo(engine, t) fmt.Println("-------------- testIterate --------------") testIterate(engine, t) + fmt.Println("-------------- testStrangeName --------------") + testStrangeName(engine, t) fmt.Println("-------------- transaction --------------") transaction(engine, t) } diff --git a/engine.go b/engine.go index a811f390..d1151d9d 100644 --- a/engine.go +++ b/engine.go @@ -32,7 +32,7 @@ type dialect interface { TableCheckSql(tableName string) (string, []interface{}) ColumnCheckSql(tableName, colName string) (string, []interface{}) - GetColumns(tableName string) (map[string]*Column, error) + GetColumns(tableName string) ([]string, map[string]*Column, error) GetTables() ([]*Table, error) GetIndexes(tableName string) (map[string]*Index, error) } @@ -189,11 +189,12 @@ func (engine *Engine) DBMetas() ([]*Table, error) { } for _, table := range tables { - cols, err := engine.dialect.GetColumns(table.Name) + colSeq, cols, err := engine.dialect.GetColumns(table.Name) if err != nil { return nil, err } table.Columns = cols + table.ColumnsSeq = colSeq indexes, err := engine.dialect.GetIndexes(table.Name) if err != nil { diff --git a/examples/maxconnect.go b/examples/maxconnect.go index b930abca..243ed544 100644 --- a/examples/maxconnect.go +++ b/examples/maxconnect.go @@ -1,15 +1,15 @@ package main import ( - //xorm "github.com/lunny/xorm" "fmt" _ "github.com/go-sql-driver/mysql" + xorm "github.com/lunny/xorm" _ "github.com/mattn/go-sqlite3" "os" //"time" //"sync/atomic" "runtime" - xorm "xorm" + //xorm "xorm" ) type User struct { diff --git a/mysql.go b/mysql.go index 42412e82..55395904 100644 --- a/mysql.go +++ b/mysql.go @@ -147,20 +147,21 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { return sql, args } -func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { +func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { args := []interface{}{db.dbname, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" cnn, err := sql.Open(db.drivername, db.dataSourceName) if err != nil { - return nil, err + return nil, nil, err } defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { - return nil, err + return nil, nil, err } cols := make(map[string]*Column) + colSeq := make([]string, 0) for _, record := range res { col := new(Column) col.Indexes = make(map[string]bool) @@ -183,12 +184,12 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { lens := strings.Split(cts[1][0:idx], ",") len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) if err != nil { - return nil, err + return nil, nil, err } if len(lens) == 2 { len2, err = strconv.Atoi(lens[1]) if err != nil { - return nil, err + return nil, nil, err } } } @@ -199,7 +200,7 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { if _, ok := sqlTypes[colType]; ok { col.SQLType = SQLType{colType, len1, len2} } else { - return nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) } case "COLUMN_KEY": key := string(content) @@ -222,8 +223,9 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { } } cols[col.Name] = col + colSeq = append(colSeq, col.Name) } - return cols, nil + return colSeq, cols, nil } func (db *mysql) GetTables() ([]*Table, error) { @@ -288,7 +290,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { if indexName == "PRIMARY" { continue } - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { indexName = indexName[5+len(tableName) : len(indexName)] } diff --git a/pool.go b/pool.go index e1e4a4fe..663fc02c 100644 --- a/pool.go +++ b/pool.go @@ -6,6 +6,7 @@ import ( "sync" //"sync/atomic" "container/list" + "reflect" "time" ) @@ -176,6 +177,10 @@ func (p *SysConnectPool) MaxIdleConns() int { // not implemented func (p *SysConnectPool) SetMaxConns(conns int) { p.maxConns = conns + // if support SetMaxOpenConns, go 1.2+, then set + if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() { + reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)}) + } //p.db.SetMaxOpenConns(conns) } diff --git a/postgres.go b/postgres.go index 51b9c85a..d98d2377 100644 --- a/postgres.go +++ b/postgres.go @@ -140,21 +140,22 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa " AND column_name = ?", args } -func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { +func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { args := []interface{}{tableName} s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" cnn, err := sql.Open(db.drivername, db.dataSourceName) if err != nil { - return nil, err + return nil, nil, err } defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { - return nil, err + return nil, nil, err } cols := make(map[string]*Column) + colSeq := make([]string, 0) for _, record := range res { col := new(Column) col.Indexes = make(map[string]bool) @@ -191,12 +192,12 @@ func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} } if _, ok := sqlTypes[col.SQLType.Name]; !ok { - return nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) } case "character_maximum_length": i, err := strconv.Atoi(string(content)) if err != nil { - return nil, errors.New("retrieve length error") + return nil, nil, errors.New("retrieve length error") } col.Length = i case "numeric_precision": @@ -209,9 +210,10 @@ func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { } } cols[col.Name] = col + colSeq = append(colSeq, col.Name) } - return cols, nil + return colSeq, cols, nil } func (db *postgres) GetTables() ([]*Table, error) { @@ -279,7 +281,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { if strings.HasSuffix(indexName, "_pkey") { continue } - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { indexName = indexName[5+len(tableName) : len(indexName)] } diff --git a/sqlite3.go b/sqlite3.go index b088bfa7..c0cb8bda 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -78,17 +78,17 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac return sql, args } -func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { +func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" cnn, err := sql.Open(db.drivername, db.dataSourceName) if err != nil { - return nil, err + return nil, nil, err } defer cnn.Close() res, err := query(cnn, s, args...) if err != nil { - return nil, err + return nil, nil, err } var sql string @@ -104,6 +104,7 @@ func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { nEnd := strings.Index(sql, ")") colCreates := strings.Split(sql[nStart+1:nEnd], ",") cols := make(map[string]*Column) + colSeq := make([]string, 0) for _, colStr := range colCreates { fields := strings.Fields(strings.TrimSpace(colStr)) col := new(Column) @@ -130,8 +131,9 @@ func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { } } cols[col.Name] = col + colSeq = append(colSeq, col.Name) } - return cols, nil + return colSeq, cols, nil } func (db *sqlite3) GetTables() ([]*Table, error) { @@ -192,7 +194,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { nNEnd := strings.Index(sql, "ON") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") //fmt.Println(indexName) - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { index.Name = indexName[5+len(tableName) : len(indexName)] } else { index.Name = indexName diff --git a/xorm/go.go b/xorm/go.go index a39b5f4a..a5f009e4 100644 --- a/xorm/go.go +++ b/xorm/go.go @@ -1,8 +1,11 @@ package main import ( + "errors" + "fmt" "github.com/lunny/xorm" "go/format" + "reflect" "strings" "text/template" ) @@ -13,12 +16,151 @@ var ( "Type": typestring, "Tag": tag, "UnTitle": unTitle, + "gt": gt, + "getCol": getCol, }, formatGo, genGoImports, } ) +var ( + errBadComparisonType = errors.New("invalid type for comparison") + errBadComparison = errors.New("incompatible types for comparison") + errNoComparison = errors.New("missing argument for comparison") +) + +type kind int + +const ( + invalidKind kind = iota + boolKind + complexKind + intKind + floatKind + integerKind + stringKind + uintKind +) + +func basicKind(v reflect.Value) (kind, error) { + switch v.Kind() { + case reflect.Bool: + return boolKind, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intKind, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintKind, nil + case reflect.Float32, reflect.Float64: + return floatKind, nil + case reflect.Complex64, reflect.Complex128: + return complexKind, nil + case reflect.String: + return stringKind, nil + } + return invalidKind, errBadComparisonType +} + +// eq evaluates the comparison a == b || a == c || ... +func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + if len(arg2) == 0 { + return false, errNoComparison + } + for _, arg := range arg2 { + v2 := reflect.ValueOf(arg) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind: + truth = v1.Bool() == v2.Bool() + case complexKind: + truth = v1.Complex() == v2.Complex() + case floatKind: + truth = v1.Float() == v2.Float() + case intKind: + truth = v1.Int() == v2.Int() + case stringKind: + truth = v1.String() == v2.String() + case uintKind: + truth = v1.Uint() == v2.Uint() + default: + panic("invalid kind") + } + if truth { + return true, nil + } + } + return false, nil +} + +// lt evaluates the comparison a < b. +func lt(arg1, arg2 interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + v2 := reflect.ValueOf(arg2) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + panic("invalid kind") + } + return truth, nil +} + +// le evaluates the comparison <= b. +func le(arg1, arg2 interface{}) (bool, error) { + // <= is < or ==. + lessThan, err := lt(arg1, arg2) + if lessThan || err != nil { + return lessThan, err + } + return eq(arg1, arg2) +} + +// gt evaluates the comparison a > b. +func gt(arg1, arg2 interface{}) (bool, error) { + // > is the inverse of <=. + lessOrEqual, err := le(arg1, arg2) + if err != nil { + return false, err + } + return !lessOrEqual, nil +} + +func getCol(cols map[string]*xorm.Column, name string) *xorm.Column { + return cols[name] +} + func formatGo(src string) (string, error) { source, err := format.Source([]byte(src)) if err != nil { @@ -94,6 +236,16 @@ func tag(table *xorm.Table, col *xorm.Column) string { res = append(res, uistr) } + nstr := col.SQLType.Name + if col.Length != 0 { + if col.Length2 != 0 { + nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2) + } else { + nstr += fmt.Sprintf("(%v)", col.Length) + } + } + res = append(res, nstr) + var tags []string if genJson { tags = append(tags, "json:\""+col.Name+"\"") diff --git a/xorm/templates/goxorm/struct.go.tpl b/xorm/templates/goxorm/struct.go.tpl index 71875f7e..91b00854 100644 --- a/xorm/templates/goxorm/struct.go.tpl +++ b/xorm/templates/goxorm/struct.go.tpl @@ -1,13 +1,17 @@ package {{.Model}} +{{$ilen := len .Imports}} +{{if gt $ilen 0}} import ( {{range .Imports}}"{{.}}"{{end}} ) +{{end}} {{range .Tables}} type {{Mapper .Name}} struct { {{$table := .}} -{{range .Columns}} {{Mapper .Name}} {{Type .}} {{Tag $table .}} +{{$columns := .Columns}} +{{range .ColumnsSeq}}{{$col := getCol $columns .}} {{Mapper $col.Name}} {{Type $col}} {{Tag $table $col}} {{end}} } diff --git a/xorm/xorm.go b/xorm/xorm.go index 2b8c486e..78583a6a 100644 --- a/xorm/xorm.go +++ b/xorm/xorm.go @@ -88,7 +88,7 @@ Use "xorm help [topic]" for more information about that topic. ` -var helpTemplate = `{{if .Runnable}}usage: go {{.UsageLine}} +var helpTemplate = `{{if .Runnable}}usage: xorm {{.UsageLine}} {{end}}{{.Long | trim}} `