added type and sequence for xorm tool;added max connect for pool(go1.2+)

This commit is contained in:
Lunny Xiao 2013-10-27 09:10:20 +08:00
parent 7a109f220f
commit bab16dc763
10 changed files with 224 additions and 26 deletions

View File

@ -1128,6 +1128,34 @@ func testIterate(engine *Engine, t *testing.T) {
} }
} }
type StrangeName struct {
Id_t int64 `xorm:"pk autoincr"`
Name string
}
func testStrangeName(engine *Engine, t *testing.T) {
err := engine.DropTables(new(StrangeName))
if err != nil {
t.Error(err)
}
err = engine.CreateTables(new(StrangeName))
if err != nil {
t.Error(err)
}
_, err = engine.Insert(&StrangeName{Name: "sfsfdsfds"})
if err != nil {
t.Error(err)
}
beans := make([]StrangeName, 0)
err = engine.Find(&beans)
if err != nil {
t.Error(err)
}
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -1210,6 +1238,8 @@ func testAll2(engine *Engine, t *testing.T) {
testMetaInfo(engine, t) testMetaInfo(engine, t)
fmt.Println("-------------- testIterate --------------") fmt.Println("-------------- testIterate --------------")
testIterate(engine, t) testIterate(engine, t)
fmt.Println("-------------- testStrangeName --------------")
testStrangeName(engine, t)
fmt.Println("-------------- transaction --------------") fmt.Println("-------------- transaction --------------")
transaction(engine, t) transaction(engine, t)
} }

View File

@ -32,7 +32,7 @@ type dialect interface {
TableCheckSql(tableName string) (string, []interface{}) TableCheckSql(tableName string) (string, []interface{})
ColumnCheckSql(tableName, colName string) (string, []interface{}) ColumnCheckSql(tableName, colName string) (string, []interface{})
GetColumns(tableName string) (map[string]*Column, error) GetColumns(tableName string) ([]string, map[string]*Column, error)
GetTables() ([]*Table, error) GetTables() ([]*Table, error)
GetIndexes(tableName string) (map[string]*Index, error) GetIndexes(tableName string) (map[string]*Index, error)
} }
@ -189,11 +189,12 @@ func (engine *Engine) DBMetas() ([]*Table, error) {
} }
for _, table := range tables { for _, table := range tables {
cols, err := engine.dialect.GetColumns(table.Name) colSeq, cols, err := engine.dialect.GetColumns(table.Name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
table.Columns = cols table.Columns = cols
table.ColumnsSeq = colSeq
indexes, err := engine.dialect.GetIndexes(table.Name) indexes, err := engine.dialect.GetIndexes(table.Name)
if err != nil { if err != nil {

View File

@ -1,15 +1,15 @@
package main package main
import ( import (
//xorm "github.com/lunny/xorm"
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
xorm "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
//"time" //"time"
//"sync/atomic" //"sync/atomic"
"runtime" "runtime"
xorm "xorm" //xorm "xorm"
) )
type User struct { type User struct {

View File

@ -147,20 +147,21 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbname, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
@ -183,12 +184,12 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) {
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.Atoi(lens[1])
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
} }
} }
@ -199,7 +200,7 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) {
if _, ok := sqlTypes[colType]; ok { if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2} col.SQLType = SQLType{colType, len1, len2}
} else { } else {
return nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
} }
case "COLUMN_KEY": case "COLUMN_KEY":
key := string(content) key := string(content)
@ -222,8 +223,9 @@ func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) {
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name)
} }
return cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*Table, error) { func (db *mysql) GetTables() ([]*Table, error) {
@ -288,7 +290,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
if indexName == "PRIMARY" { if indexName == "PRIMARY" {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }

View File

@ -6,6 +6,7 @@ import (
"sync" "sync"
//"sync/atomic" //"sync/atomic"
"container/list" "container/list"
"reflect"
"time" "time"
) )
@ -176,6 +177,10 @@ func (p *SysConnectPool) MaxIdleConns() int {
// not implemented // not implemented
func (p *SysConnectPool) SetMaxConns(conns int) { func (p *SysConnectPool) SetMaxConns(conns int) {
p.maxConns = conns p.maxConns = conns
// if support SetMaxOpenConns, go 1.2+, then set
if reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").IsValid() {
reflect.ValueOf(p.db).MethodByName("SetMaxOpenConns").Call([]reflect.Value{reflect.ValueOf(conns)})
}
//p.db.SetMaxOpenConns(conns) //p.db.SetMaxOpenConns(conns)
} }

View File

@ -140,21 +140,22 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
@ -191,12 +192,12 @@ func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) {
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
} }
if _, ok := sqlTypes[col.SQLType.Name]; !ok { if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
} }
case "character_maximum_length": case "character_maximum_length":
i, err := strconv.Atoi(string(content)) i, err := strconv.Atoi(string(content))
if err != nil { if err != nil {
return nil, errors.New("retrieve length error") return nil, nil, errors.New("retrieve length error")
} }
col.Length = i col.Length = i
case "numeric_precision": case "numeric_precision":
@ -209,9 +210,10 @@ func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) {
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name)
} }
return cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
@ -279,7 +281,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }

View File

@ -78,17 +78,17 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac
return sql, args return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
var sql string var sql string
@ -104,6 +104,7 @@ func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) {
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",") colCreates := strings.Split(sql[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(Column)
@ -130,8 +131,9 @@ func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) {
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name)
} }
return cols, nil return colSeq, cols, nil
} }
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*Table, error) {
@ -192,7 +194,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName) //fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "QUE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else { } else {
index.Name = indexName index.Name = indexName

View File

@ -1,8 +1,11 @@
package main package main
import ( import (
"errors"
"fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"go/format" "go/format"
"reflect"
"strings" "strings"
"text/template" "text/template"
) )
@ -13,12 +16,151 @@ var (
"Type": typestring, "Type": typestring,
"Tag": tag, "Tag": tag,
"UnTitle": unTitle, "UnTitle": unTitle,
"gt": gt,
"getCol": getCol,
}, },
formatGo, formatGo,
genGoImports, genGoImports,
} }
) )
var (
errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison")
)
type kind int
const (
invalidKind kind = iota
boolKind
complexKind
intKind
floatKind
integerKind
stringKind
uintKind
)
func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() {
case reflect.Bool:
return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil
case reflect.Float32, reflect.Float64:
return floatKind, nil
case reflect.Complex64, reflect.Complex128:
return complexKind, nil
case reflect.String:
return stringKind, nil
}
return invalidKind, errBadComparisonType
}
// eq evaluates the comparison a == b || a == c || ...
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
if len(arg2) == 0 {
return false, errNoComparison
}
for _, arg := range arg2 {
v2 := reflect.ValueOf(arg)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
if k1 != k2 {
return false, errBadComparison
}
truth := false
switch k1 {
case boolKind:
truth = v1.Bool() == v2.Bool()
case complexKind:
truth = v1.Complex() == v2.Complex()
case floatKind:
truth = v1.Float() == v2.Float()
case intKind:
truth = v1.Int() == v2.Int()
case stringKind:
truth = v1.String() == v2.String()
case uintKind:
truth = v1.Uint() == v2.Uint()
default:
panic("invalid kind")
}
if truth {
return true, nil
}
}
return false, nil
}
// lt evaluates the comparison a < b.
func lt(arg1, arg2 interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1)
if err != nil {
return false, err
}
v2 := reflect.ValueOf(arg2)
k2, err := basicKind(v2)
if err != nil {
return false, err
}
if k1 != k2 {
return false, errBadComparison
}
truth := false
switch k1 {
case boolKind, complexKind:
return false, errBadComparisonType
case floatKind:
truth = v1.Float() < v2.Float()
case intKind:
truth = v1.Int() < v2.Int()
case stringKind:
truth = v1.String() < v2.String()
case uintKind:
truth = v1.Uint() < v2.Uint()
default:
panic("invalid kind")
}
return truth, nil
}
// le evaluates the comparison <= b.
func le(arg1, arg2 interface{}) (bool, error) {
// <= is < or ==.
lessThan, err := lt(arg1, arg2)
if lessThan || err != nil {
return lessThan, err
}
return eq(arg1, arg2)
}
// gt evaluates the comparison a > b.
func gt(arg1, arg2 interface{}) (bool, error) {
// > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2)
if err != nil {
return false, err
}
return !lessOrEqual, nil
}
func getCol(cols map[string]*xorm.Column, name string) *xorm.Column {
return cols[name]
}
func formatGo(src string) (string, error) { func formatGo(src string) (string, error) {
source, err := format.Source([]byte(src)) source, err := format.Source([]byte(src))
if err != nil { if err != nil {
@ -94,6 +236,16 @@ func tag(table *xorm.Table, col *xorm.Column) string {
res = append(res, uistr) res = append(res, uistr)
} }
nstr := col.SQLType.Name
if col.Length != 0 {
if col.Length2 != 0 {
nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2)
} else {
nstr += fmt.Sprintf("(%v)", col.Length)
}
}
res = append(res, nstr)
var tags []string var tags []string
if genJson { if genJson {
tags = append(tags, "json:\""+col.Name+"\"") tags = append(tags, "json:\""+col.Name+"\"")

View File

@ -1,13 +1,17 @@
package {{.Model}} package {{.Model}}
{{$ilen := len .Imports}}
{{if gt $ilen 0}}
import ( import (
{{range .Imports}}"{{.}}"{{end}} {{range .Imports}}"{{.}}"{{end}}
) )
{{end}}
{{range .Tables}} {{range .Tables}}
type {{Mapper .Name}} struct { type {{Mapper .Name}} struct {
{{$table := .}} {{$table := .}}
{{range .Columns}} {{Mapper .Name}} {{Type .}} {{Tag $table .}} {{$columns := .Columns}}
{{range .ColumnsSeq}}{{$col := getCol $columns .}} {{Mapper $col.Name}} {{Type $col}} {{Tag $table $col}}
{{end}} {{end}}
} }

View File

@ -88,7 +88,7 @@ Use "xorm help [topic]" for more information about that topic.
` `
var helpTemplate = `{{if .Runnable}}usage: go {{.UsageLine}} var helpTemplate = `{{if .Runnable}}usage: xorm {{.UsageLine}}
{{end}}{{.Long | trim}} {{end}}{{.Long | trim}}
` `