Merge remote-tracking branch 'upstream/dev' into dev

This commit is contained in:
Nash Tsai 2014-01-09 15:41:55 +08:00
commit cf73ec9f2d
57 changed files with 6359 additions and 5840 deletions

View File

@ -1,17 +1,17 @@
[中文](https://github.com/lunny/xorm/blob/master/README_CN.md) [中文](https://github.com/lunny/xorm/blob/master/README_CN.md)
Xorm is a simple and powerful ORM for Go. Xorm is a simple and powerful ORM for Go.
[![Build Status](https://drone.io/github.com/lunny/xorm/status.png)](https://drone.io/github.com/lunny/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/lunny/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") [![Build Status](https://drone.io/github.com/lunny/xorm/status.png)](https://drone.io/github.com/lunny/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/lunny/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge")
# Features # Features
* Struct <-> Table Mapping Support * Struct <-> Table Mapping Support
* Chainable APIs * Chainable APIs
* Transaction Support * Transaction Support
* Both ORM and raw SQL operation Support * Both ORM and raw SQL operation Support
* Sync database sechmea Support * Sync database sechmea Support
@ -24,26 +24,26 @@ Xorm is a simple and powerful ORM for Go.
* Optimistic Locking support * Optimistic Locking support
# Drivers Support # Drivers Support
Drivers for Go's sql package which currently support database/sql includes: Drivers for Go's sql package which currently support database/sql includes:
* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) * Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/lib/pq](https://github.com/lib/pq) * Postgres: [github.com/lib/pq](https://github.com/lib/pq)
* MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc)
# Changelog # Changelog
* **v0.3.1** * **v0.3.1**
Features: Features:
* Support MSSQL DB via ODBC driver ([github.com/lunny/godbc](https://github.com/lunny/godbc)); * Support MSSQL DB via ODBC driver ([github.com/lunny/godbc](https://github.com/lunny/godbc));
* Composite Key, using multiple pk xorm tag * Composite Key, using multiple pk xorm tag
* Added Row() API as alternative to Iterate() API for traversing result set, provide similar usages to sql.Rows type * Added Row() API as alternative to Iterate() API for traversing result set, provide similar usages to sql.Rows type
@ -54,22 +54,22 @@ Drivers for Go's sql package which currently support database/sql includes:
* Allowed int/int32/int64/uint/uint32/uint64/string as Primary Key type * Allowed int/int32/int64/uint/uint32/uint64/string as Primary Key type
* Performance improvement for Get()/Find()/Iterate() * Performance improvement for Get()/Find()/Iterate()
[More changelogs ...](https://github.com/lunny/xorm/blob/master/docs/Changelog.md) [More changelogs ...](https://github.com/lunny/xorm/blob/master/docs/Changelog.md)
# Installation # Installation
If you have [gopm](https://github.com/gpmgo/gopm) installed, If you have [gopm](https://github.com/gpmgo/gopm) installed,
gopm get github.com/lunny/xorm gopm get github.com/lunny/xorm
Or Or
go get github.com/lunny/xorm go get github.com/lunny/xorm
# Documents # Documents
* [GoDoc](http://godoc.org/github.com/lunny/xorm) * [GoDoc](http://godoc.org/github.com/lunny/xorm)
* [GoWalker](http://gowalker.org/github.com/lunny/xorm) * [GoWalker](http://gowalker.org/github.com/lunny/xorm)
* [Quick Start](https://github.com/lunny/xorm/blob/master/docs/QuickStartEn.md) * [Quick Start](https://github.com/lunny/xorm/blob/master/docs/QuickStartEn.md)
@ -82,12 +82,12 @@ Or
* [Godaily](http://godaily.org) - [github.com/govc/godaily](http://github.com/govc/godaily) * [Godaily](http://godaily.org) - [github.com/govc/godaily](http://github.com/govc/godaily)
* [Very Hour](http://veryhour.com/) * [Very Hour](http://veryhour.com/)
# Todo # Todo
[Todo List](https://trello.com/b/IHsuAnhk/xorm) [Todo List](https://trello.com/b/IHsuAnhk/xorm)
# Discuss # Discuss
Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm)
@ -97,9 +97,9 @@ Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xor
If you want to pull request, please see [CONTRIBUTING](https://github.com/lunny/xorm/blob/master/CONTRIBUTING.md) If you want to pull request, please see [CONTRIBUTING](https://github.com/lunny/xorm/blob/master/CONTRIBUTING.md)
* [Lunny](https://github.com/lunny) * [Lunny](https://github.com/lunny)
* [Nashtsai](https://github.com/nashtsai) * [Nashtsai](https://github.com/nashtsai)
# LICENSE # LICENSE
BSD License BSD License
[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)

View File

@ -1,7 +1,7 @@
# xorm # xorm
[English](https://github.com/lunny/xorm/blob/master/README.md) [English](https://github.com/lunny/xorm/blob/master/README.md)
xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。
[![Build Status](https://drone.io/github.com/lunny/xorm/status.png)](https://drone.io/github.com/lunny/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/lunny/xorm) [![Build Status](https://drone.io/github.com/lunny/xorm/status.png)](https://drone.io/github.com/lunny/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/lunny/xorm)
@ -27,11 +27,11 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* 支持记录版本(即乐观锁) * 支持记录版本(即乐观锁)
## 驱动支持 ## 驱动支持
目前支持的Go数据库驱动和对应的数据库如下 目前支持的Go数据库驱动和对应的数据库如下
* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) * Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
@ -42,7 +42,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
## 更新日志 ## 更新日志
* **v0.3.1** * **v0.3.1**
新特性: 新特性:
* 支持 MSSQL DB 通过 ODBC 驱动 ([github.com/lunny/godbc](https://github.com/lunny/godbc)); * 支持 MSSQL DB 通过 ODBC 驱动 ([github.com/lunny/godbc](https://github.com/lunny/godbc));
@ -53,11 +53,11 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
改进: 改进:
* 允许 int/int32/int64/uint/uint32/uint64/string 作为主键类型 * 允许 int/int32/int64/uint/uint32/uint64/string 作为主键类型
* 查询函数 Get()/Find()/Iterate() 在性能上的改进 * 查询函数 Get()/Find()/Iterate() 在性能上的改进
[更多更新日志...](https://github.com/lunny/xorm/blob/master/docs/ChangelogCN.md) [更多更新日志...](https://github.com/lunny/xorm/blob/master/docs/ChangelogCN.md)
## 安装 ## 安装
推荐使用 [gopm](https://github.com/gpmgo/gopm) 进行安装: 推荐使用 [gopm](https://github.com/gpmgo/gopm) 进行安装:
@ -66,10 +66,10 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
或者您也可以使用go工具进行安装 或者您也可以使用go工具进行安装
go get github.com/lunny/xorm go get github.com/lunny/xorm
## 文档 ## 文档
* [快速开始](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md) * [快速开始](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md)
* [GoWalker代码文档](http://gowalker.org/github.com/lunny/xorm) * [GoWalker代码文档](http://gowalker.org/github.com/lunny/xorm)
@ -85,8 +85,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* [Godaily](http://godaily.org) - [github.com/govc/godaily](http://github.com/govc/godaily) * [Godaily](http://godaily.org) - [github.com/govc/godaily](http://github.com/govc/godaily)
* [Very Hour](http://veryhour.com/) * [Very Hour](http://veryhour.com/)
## Todo ## Todo
[开发计划](https://trello.com/b/IHsuAnhk/xorm) [开发计划](https://trello.com/b/IHsuAnhk/xorm)
@ -101,8 +101,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作
* [Lunny](https://github.com/lunny) * [Lunny](https://github.com/lunny)
* [Nashtsai](https://github.com/nashtsai) * [Nashtsai](https://github.com/nashtsai)
## LICENSE ## LICENSE
BSD License BSD License
[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)

View File

@ -6,6 +6,8 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/lunny/xorm/core"
) )
/* /*
@ -1850,7 +1852,7 @@ func testMetaInfo(engine *Engine, t *testing.T) {
for _, table := range tables { for _, table := range tables {
fmt.Println(table.Name) fmt.Println(table.Name)
for _, col := range table.Columns { for _, col := range table.Columns() {
fmt.Println(col.String(engine.dialect)) fmt.Println(col.String(engine.dialect))
} }
@ -3175,7 +3177,9 @@ func testPointerData(engine *Engine, t *testing.T) {
// using instance type should just work too // using instance type should just work too
nullData2Get := NullData2{} nullData2Get := NullData2{}
has, err = engine.Table("null_data").Id(nullData.Id).Get(&nullData2Get) tableName := engine.tableMapper.Obj2Table("NullData")
has, err = engine.Table(tableName).Id(nullData.Id).Get(&nullData2Get)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -3525,8 +3529,8 @@ func testNullValue(engine *Engine, t *testing.T) {
// skipped postgres test due to postgres driver doesn't read time.Time's timzezone info when stored in the db // skipped postgres test due to postgres driver doesn't read time.Time's timzezone info when stored in the db
// mysql and sqlite3 seem have done this correctly by storing datatime in UTC timezone, I think postgres driver // mysql and sqlite3 seem have done this correctly by storing datatime in UTC timezone, I think postgres driver
// prefer using timestamp with timezone to sovle the issue // prefer using timestamp with timezone to sovle the issue
if engine.DriverName != POSTGRES && engine.DriverName != MYMYSQL && if engine.DriverName != core.POSTGRES && engine.DriverName != "mymysql" &&
engine.DriverName != MYSQL { engine.DriverName != core.MYSQL {
if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() { if (*nullDataGet.TimePtr).Unix() != (*nullDataUpdate.TimePtr).Unix() {
t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr))) t.Error(errors.New(fmt.Sprintf("inserted value unmatch: [%v]:[%v]", *nullDataGet.TimePtr, *nullDataUpdate.TimePtr)))
} else { } else {
@ -3540,7 +3544,9 @@ func testNullValue(engine *Engine, t *testing.T) {
// update to null values // update to null values
nullDataUpdate = NullData{} nullDataUpdate = NullData{}
cnt, err = engine.Id(nullData.Id).Cols("string_ptr").Update(&nullDataUpdate) string_ptr := engine.columnMapper.Obj2Table("StringPtr")
cnt, err = engine.Id(nullData.Id).Cols(string_ptr).Update(&nullDataUpdate)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -3896,3 +3902,11 @@ func testAll3(engine *Engine, t *testing.T) {
fmt.Println("-------------- testStringPK --------------") fmt.Println("-------------- testStringPK --------------")
testStringPK(engine, t) testStringPK(engine, t)
} }
func testAllSnakeMapper(engine *Engine, t *testing.T) {
}
func testAllSameMapper(engine *Engine, t *testing.T) {
}

113
core/column.go Normal file
View File

@ -0,0 +1,113 @@
package core
import (
"fmt"
"reflect"
"strings"
)
const (
TWOSIDES = iota + 1
ONLYTODB
ONLYFROMDB
)
// database column
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Length2 int
Nullable bool
Default string
Indexes map[string]bool
IsPrimaryKey bool
IsAutoIncrement bool
MapType int
IsCreated bool
IsUpdated bool
IsCascade bool
IsVersion bool
fieldPath []string
}
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
return &Column{name, fieldName, sqlType, len1, len2, nullable, "", make(map[string]bool), false, false,
TWOSIDES, false, false, false, false, nil}
}
// generate column description string according dialect
func (col *Column) String(d Dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
}
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
func (col *Column) StringNoPk(d Dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " "
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
// return col's filed of struct's value
func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
dataStruct := reflect.Indirect(reflect.ValueOf(bean))
return col.ValueOfV(&dataStruct)
}
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var fieldValue reflect.Value
var err error
if col.fieldPath == nil {
col.fieldPath = strings.Split(col.FieldName, ".")
}
if len(col.fieldPath) == 1 {
fieldValue = dataStruct.FieldByName(col.FieldName)
} else if len(col.fieldPath) == 2 {
parentField := dataStruct.FieldByName(col.fieldPath[0])
if parentField.IsValid() {
fieldValue = parentField.FieldByName(col.fieldPath[1])
} else {
err = fmt.Errorf("field %v is not valid", col.FieldName)
}
} else {
err = fmt.Errorf("Unsupported mutliderive %v", col.FieldName)
}
if err != nil {
return nil, err
}
return &fieldValue, nil
}

8
core/converstion.go Normal file
View File

@ -0,0 +1,8 @@
package core
// Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database.
type Conversion interface {
FromDB([]byte) error
ToDB() ([]byte, error)
}

44
core/db.go Normal file
View File

@ -0,0 +1,44 @@
package core
import (
"database/sql"
"reflect"
)
type DB struct {
*sql.DB
}
func Open(driverName, dataSourceName string) (*DB, error) {
db, err := sql.Open(driverName, dataSourceName)
return &DB{db}, err
}
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
rows, err := db.DB.Query(query, args...)
return &Rows{rows}, err
}
type Rows struct {
*sql.Rows
}
func (rs *Rows) Scan(dest ...interface{}) error {
newDest := make([]interface{}, 0)
for _, s := range dest {
vv := reflect.ValueOf(s)
switch vv.Kind() {
case reflect.Ptr:
vvv := vv.Elem()
if vvv.Kind() == reflect.Struct {
for j := 0; j < vvv.NumField(); j++ {
newDest = append(newDest, vvv.FieldByIndex([]int{j}).Addr().Interface())
}
} else {
newDest = append(newDest, s)
}
}
}
return rs.Rows.Scan(newDest...)
}

53
core/db_test.go Normal file
View File

@ -0,0 +1,53 @@
package core
import (
"fmt"
"testing"
_ "github.com/mattn/go-sqlite3"
)
var (
createTableSqlite3 = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, `title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL);"
)
type User struct {
Id int64
Name string
Title string
Age float32
Alias string
NickName string
}
func TestQuery(t *testing.T) {
db, err := Open("sqlite3", "./test.db")
if err != nil {
t.Error(err)
}
_, err = db.Exec(createTableSqlite3)
if err != nil {
t.Error(err)
}
_, err = db.Exec("insert into user (name, title, age, alias, nick_name) values (?,?,?,?,?)",
"xlw", "tester", 1.2, "lunny", "lunny xiao")
if err != nil {
t.Error(err)
}
rows, err := db.Query("select * from user")
if err != nil {
t.Error(err)
}
for rows.Next() {
var user User
err = rows.Scan(&user)
if err != nil {
t.Error(err)
}
fmt.Println(user)
}
}

144
core/dialect.go Normal file
View File

@ -0,0 +1,144 @@
package core
import (
"strings"
"time"
)
type dbType string
type Uri struct {
DbType dbType
Proto string
Host string
Port string
DbName string
User string
Passwd string
Charset string
Laddr string
Raddr string
Timeout time.Duration
}
// a dialect is a driver's wrapper
type Dialect interface {
Init(*Uri, string, string) error
URI() *Uri
DBType() dbType
SqlType(t *Column) string
SupportInsertMany() bool
QuoteStr() string
AutoIncrStr() string
SupportEngine() bool
SupportCharset() bool
IndexOnTable() bool
IndexCheckSql(tableName, idxName string) (string, []interface{})
TableCheckSql(tableName string) (string, []interface{})
ColumnCheckSql(tableName, colName string) (string, []interface{})
GetColumns(tableName string) ([]string, map[string]*Column, error)
GetTables() ([]*Table, error)
GetIndexes(tableName string) (map[string]*Index, error)
CreateTableSql(table *Table, tableName, storeEngine, charset string) string
Filters() []Filter
DriverName() string
DataSourceName() string
}
type Base struct {
dialect Dialect
driverName string
dataSourceName string
*Uri
}
func (b *Base) Init(dialect Dialect, uri *Uri, drivername, dataSourceName string) error {
b.dialect = dialect
b.driverName, b.dataSourceName = drivername, dataSourceName
b.Uri = uri
return nil
}
func (b *Base) URI() *Uri {
return b.Uri
}
func (b *Base) DBType() dbType {
return b.Uri.DbType
}
func (b *Base) DriverName() string {
return b.driverName
}
func (b *Base) DataSourceName() string {
return b.dataSourceName
}
func (b *Base) Quote(c string) string {
return b.dialect.QuoteStr() + c + b.dialect.QuoteStr()
}
func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset string) string {
var sql string
sql = "CREATE TABLE IF NOT EXISTS "
if tableName == "" {
tableName = table.Name
}
sql += b.Quote(tableName) + " ("
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(b.dialect)
} else {
sql += col.StringNoPk(b.dialect)
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += strings.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
if b.dialect.SupportEngine() && storeEngine != "" {
sql += " ENGINE=" + storeEngine
}
if b.dialect.SupportCharset() {
if charset == "" {
charset = b.dialect.URI().Charset
}
sql += " DEFAULT CHARSET " + charset
}
sql += ";"
return sql
}
var (
dialects = map[dbType]Dialect{}
)
func RegisterDialect(dbName dbType, dialect Dialect) {
if dialect == nil {
panic("core: Register dialect is nil")
}
if _, dup := dialects[dbName]; dup {
panic("core: Register called twice for dialect " + dbName)
}
dialects[dbName] = dialect
}
func QueryDialect(dbName dbType) Dialect {
return dialects[dbName]
}

23
core/driver.go Normal file
View File

@ -0,0 +1,23 @@
package core
type driver interface {
Parse(string, string) (*Uri, error)
}
var (
drivers = map[string]driver{}
)
func RegisterDriver(driverName string, driver driver) {
if driver == nil {
panic("core: Register driver is nil")
}
if _, dup := drivers[driverName]; dup {
panic("core: Register called twice for driver " + driverName)
}
drivers[driverName] = driver
}
func QueryDriver(driverName string) driver {
return drivers[driverName]
}

42
core/filter.go Normal file
View File

@ -0,0 +1,42 @@
package core
import "strings"
// Filter is an interface to filter SQL
type Filter interface {
Do(sql string, dialect Dialect, table *Table) string
}
// QuoteFilter filter SQL replace ` to database's own quote character
type QuoteFilter struct {
}
func (s *QuoteFilter) Do(sql string, dialect Dialect, table *Table) string {
return strings.Replace(sql, "`", dialect.QuoteStr(), -1)
}
// IdFilter filter SQL replace (id) to primary key column name
type IdFilter struct {
}
type Quoter struct {
dialect Dialect
}
func NewQuoter(dialect Dialect) *Quoter {
return &Quoter{dialect}
}
func (q *Quoter) Quote(content string) string {
return q.dialect.QuoteStr() + content + q.dialect.QuoteStr()
}
func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string {
quoter := NewQuoter(dialect)
if table != nil && len(table.PrimaryKeys) == 1 {
sql = strings.Replace(sql, "`(id)`", quoter.Quote(table.PrimaryKeys[0]), -1)
sql = strings.Replace(sql, quoter.Quote("(id)"), quoter.Quote(table.PrimaryKeys[0]), -1)
return strings.Replace(sql, "(id)", quoter.Quote(table.PrimaryKeys[0]), -1)
}
return sql
}

25
core/index.go Normal file
View File

@ -0,0 +1,25 @@
package core
const (
IndexType = iota + 1
UniqueType
)
// database index
type Index struct {
Name string
Type int
Cols []string
}
// add columns which will be composite index
func (index *Index) AddColumn(cols ...string) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
// new an index
func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)}
}

94
core/table.go Normal file
View File

@ -0,0 +1,94 @@
package core
import (
"reflect"
"strings"
)
// database table
type Table struct {
Name string
Type reflect.Type
columnsSeq []string
columns map[string]*Column
Indexes map[string]*Index
PrimaryKeys []string
AutoIncrement string
Created map[string]bool
Updated string
Version string
}
func (table *Table) Columns() map[string]*Column {
return table.columns
}
func (table *Table) ColumnsSeq() []string {
return table.columnsSeq
}
func NewEmptyTable() *Table {
return &Table{columnsSeq: make([]string, 0),
columns: make(map[string]*Column),
Indexes: make(map[string]*Index),
Created: make(map[string]bool),
PrimaryKeys: make([]string, 0),
}
}
func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t,
columnsSeq: make([]string, 0),
columns: make(map[string]*Column),
Indexes: make(map[string]*Index),
Created: make(map[string]bool),
PrimaryKeys: make([]string, 0),
}
}
func (table *Table) GetColumn(name string) *Column {
return table.columns[strings.ToLower(name)]
}
// if has primary key, return column
func (table *Table) PKColumns() []*Column {
columns := make([]*Column, 0)
for _, name := range table.PrimaryKeys {
columns = append(columns, table.GetColumn(name))
}
return columns
}
func (table *Table) AutoIncrColumn() *Column {
return table.GetColumn(table.AutoIncrement)
}
func (table *Table) VersionColumn() *Column {
return table.GetColumn(table.Version)
}
// add a column to table
func (table *Table) AddColumn(col *Column) {
table.columnsSeq = append(table.columnsSeq, col.Name)
table.columns[strings.ToLower(col.Name)] = col
if col.IsPrimaryKey {
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
}
if col.IsAutoIncrement {
table.AutoIncrement = col.Name
}
if col.IsCreated {
table.Created[col.Name] = true
}
if col.IsUpdated {
table.Updated = col.Name
}
if col.IsVersion {
table.Version = col.Name
}
}
// add an index or an unique to table
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
}

235
core/type.go Normal file
View File

@ -0,0 +1,235 @@
package core
import (
"reflect"
"sort"
"strings"
"time"
)
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MSSQL = "mssql"
ORACLE = "oracle"
)
// xorm SQL types
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
func (s *SQLType) IsText() bool {
return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
s.Name == Text || s.Name == MediumText || s.Name == LongText
}
func (s *SQLType) IsBlob() bool {
return (s.Name == TinyBlob) || (s.Name == Blob) ||
s.Name == MediumBlob || s.Name == LongBlob ||
s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
}
var (
Bit = "BIT"
TinyInt = "TINYINT"
SmallInt = "SMALLINT"
MediumInt = "MEDIUMINT"
Int = "INT"
Integer = "INTEGER"
BigInt = "BIGINT"
Char = "CHAR"
Varchar = "VARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Date = "DATE"
DateTime = "DATETIME"
Time = "TIME"
TimeStamp = "TIMESTAMP"
TimeStampz = "TIMESTAMPZ"
Decimal = "DECIMAL"
Numeric = "NUMERIC"
Real = "REAL"
Float = "FLOAT"
Double = "DOUBLE"
Binary = "BINARY"
VarBinary = "VARBINARY"
TinyBlob = "TINYBLOB"
Blob = "BLOB"
MediumBlob = "MEDIUMBLOB"
LongBlob = "LONGBLOB"
Bytea = "BYTEA"
Bool = "BOOL"
Serial = "SERIAL"
BigSerial = "BIGSERIAL"
SqlTypes = map[string]bool{
Bit: true,
TinyInt: true,
SmallInt: true,
MediumInt: true,
Int: true,
Integer: true,
BigInt: true,
Char: true,
Varchar: true,
TinyText: true,
Text: true,
MediumText: true,
LongText: true,
Date: true,
DateTime: true,
Time: true,
TimeStamp: true,
TimeStampz: true,
Decimal: true,
Numeric: true,
Binary: true,
VarBinary: true,
Real: true,
Float: true,
Double: true,
TinyBlob: true,
Blob: true,
MediumBlob: true,
LongBlob: true,
Bytea: true,
Bool: true,
Serial: true,
BigSerial: true,
}
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparision
var (
c_EMPTY_STRING string
c_BOOL_DEFAULT bool
c_BYTE_DEFAULT byte
c_COMPLEX64_DEFAULT complex64
c_COMPLEX128_DEFAULT complex128
c_FLOAT32_DEFAULT float32
c_FLOAT64_DEFAULT float64
c_INT64_DEFAULT int64
c_UINT64_DEFAULT uint64
c_INT32_DEFAULT int32
c_UINT32_DEFAULT uint32
c_INT16_DEFAULT int16
c_UINT16_DEFAULT uint16
c_INT8_DEFAULT int8
c_UINT8_DEFAULT uint8
c_INT_DEFAULT int
c_UINT_DEFAULT uint
c_TIME_DEFAULT time.Time
)
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
st = SQLType{Int, 0, 0}
case reflect.Int64, reflect.Uint64:
st = SQLType{BigInt, 0, 0}
case reflect.Float32:
st = SQLType{Float, 0, 0}
case reflect.Float64:
st = SQLType{Double, 0, 0}
case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) {
st = SQLType{Blob, 0, 0}
} else {
st = SQLType{Text, 0, 0}
}
case reflect.Bool:
st = SQLType{Bool, 0, 0}
case reflect.String:
st = SQLType{Varchar, 255, 0}
case reflect.Struct:
if t == reflect.TypeOf(c_TIME_DEFAULT) {
st = SQLType{DateTime, 0, 0}
} else {
// TODO need to handle association struct
st = SQLType{Text, 0, 0}
}
case reflect.Ptr:
st, _ = ptrType2SQLType(t)
default:
st = SQLType{Text, 0, 0}
}
return
}
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
has = true
switch t {
case reflect.TypeOf(&c_EMPTY_STRING):
st = SQLType{Varchar, 255, 0}
return
case reflect.TypeOf(&c_BOOL_DEFAULT):
st = SQLType{Bool, 0, 0}
case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
st = SQLType{Varchar, 64, 0}
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
st = SQLType{Float, 0, 0}
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
st = SQLType{Double, 0, 0}
case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
st = SQLType{BigInt, 0, 0}
case reflect.TypeOf(&c_TIME_DEFAULT):
st = SQLType{DateTime, 0, 0}
case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
st = SQLType{Int, 0, 0}
default:
has = false
}
return
}
// default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name)
switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1)
case BigInt, BigSerial:
return reflect.TypeOf(int64(1))
case Float, Real:
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, Varchar, TinyText, Text, MediumText, LongText:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{})
case Bool:
return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz:
return reflect.TypeOf(c_TIME_DEFAULT)
case Decimal, Numeric:
return reflect.TypeOf("")
default:
return reflect.TypeOf("")
}
}

View File

@ -1,46 +1,28 @@
package xorm package dialects
import ( import (
//"crypto/tls" //"crypto/tls"
"database/sql"
"errors" "errors"
"fmt" "fmt"
//"regexp" //"regexp"
"strconv" "strconv"
"strings" "strings"
//"time" //"time"
. "github.com/lunny/xorm/core"
) )
func init() {
RegisterDialect("mssql", &mssql{})
}
type mssql struct { type mssql struct {
base Base
quoteFilter Filter
} }
type odbcParser struct { func (db *mssql) Init(uri *Uri, drivername, dataSourceName string) error {
} return db.Base.Init(db, uri, drivername, dataSourceName)
func (p *odbcParser) parse(driverName, dataSourceName string) (*uri, error) {
kv := strings.Split(dataSourceName, ";")
var dbName string
for _, c := range kv {
vv := strings.Split(strings.TrimSpace(c), "=")
if len(vv) == 2 {
switch strings.ToLower(vv[0]) {
case "database":
dbName = vv[1]
}
}
}
if dbName == "" {
return nil, errors.New("no db name provided")
}
return &uri{dbName: dbName, dbType: MSSQL}, nil
}
func (db *mssql) Init(drivername, uri string) error {
db.quoteFilter = &QuoteFilter{}
return db.base.init(&odbcParser{}, drivername, uri)
} }
func (db *mssql) SqlType(c *Column) string { func (db *mssql) SqlType(c *Column) string {
@ -139,51 +121,48 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*Column, err
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale
from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id
where a.object_id=object_id('` + tableName + `')` where a.object_id=object_id('` + tableName + `')`
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...)
rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for rows.Next() {
var name, ctype, precision, scale string
var maxLen int
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale)
if err != nil {
return nil, nil, err
}
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { col.Length = maxLen
switch name { col.Name = strings.Trim(name, "` ")
case "name":
col.Name = strings.Trim(string(content), "` ") ct := strings.ToUpper(ctype)
case "ctype": switch ct {
ct := strings.ToUpper(string(content)) case "DATETIMEOFFSET":
switch ct { col.SQLType = SQLType{TimeStampz, 0, 0}
case "DATETIMEOFFSET": case "NVARCHAR":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = SQLType{Varchar, 0, 0}
case "NVARCHAR": case "IMAGE":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = SQLType{VarBinary, 0, 0}
case "IMAGE": default:
col.SQLType = SQLType{VarBinary, 0, 0} if _, ok := SqlTypes[ct]; ok {
default: col.SQLType = SQLType{ct, 0, 0}
if _, ok := sqlTypes[ct]; ok { } else {
col.SQLType = SQLType{ct, 0, 0} return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v",
} else { ct, tableName, col.Name))
return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v",
ct, tableName, col.Name))
}
}
case "max_length":
len1, err := strconv.Atoi(strings.TrimSpace(string(content)))
if err != nil {
return nil, nil, err
}
col.Length = len1
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
@ -198,25 +177,25 @@ where a.object_id=object_id('` + tableName + `')`
func (db *mssql) GetTables() ([]*Table, error) { func (db *mssql) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := `select name from sysobjects where xtype ='U'` s := `select name from sysobjects where xtype ='U'`
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for rows.Next() {
table := new(Table) table := NewEmptyTable()
for name, content := range record { var name string
switch name { err = rows.Scan(&name)
case "name": if err != nil {
table.Name = strings.Trim(string(content), "` ") return nil, err
}
} }
table.Name = strings.Trim(name, "` ")
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
@ -238,40 +217,39 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID
AND IXCS.COLUMN_ID=C.COLUMN_ID AND IXCS.COLUMN_ID=C.COLUMN_ID
WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
` `
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for rows.Next() {
var indexType int var indexType int
var indexName, colName string var indexName, colName, isUnique string
for name, content := range record {
switch name {
case "IS_UNIQUE":
i, err := strconv.ParseBool(string(content))
if err != nil {
return nil, err
}
if i { err = rows.Scan(&indexName, &colName, &isUnique, nil)
indexType = UniqueType if err != nil {
} else { return nil, err
indexType = IndexType
}
case "INDEX_NAME":
indexName = string(content)
case "COLUMN_NAME":
colName = strings.Trim(string(content), "` ")
}
} }
i, err := strconv.ParseBool(isUnique)
if err != nil {
return nil, err
}
if i {
indexType = UniqueType
} else {
indexType = IndexType
}
colName = strings.Trim(colName, "` ")
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
@ -288,3 +266,41 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
} }
return indexes, nil return indexes, nil
} }
func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset string) string {
var sql string
if tableName == "" {
tableName = table.Name
}
sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + tableName + "' ) CREATE TABLE"
sql += db.QuoteStr() + tableName + db.QuoteStr() + " ("
pkList := table.PrimaryKeys
for _, colName := range table.ColumnsSeq() {
col := table.GetColumn(colName)
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(db)
} else {
sql += col.StringNoPk(db)
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += strings.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
sql += ";"
return sql
}
func (db *mssql) Filters() []Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}}
}

276
dialects/mysql.go Normal file
View File

@ -0,0 +1,276 @@
package dialects
import (
"crypto/tls"
"errors"
"fmt"
"strconv"
"strings"
"time"
. "github.com/lunny/xorm/core"
)
func init() {
RegisterDialect("mysql", &mysql{})
}
type mysql struct {
Base
net string
addr string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
func (db *mysql) Init(uri *Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
}
func (db *mysql) SqlType(c *Column) string {
var res string
switch t := c.SQLType.Name; t {
case Bool:
res = TinyInt
case Serial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = Int
case BigSerial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = BigInt
case Bytea:
res = Blob
case TimeStampz:
res = Char
c.Length = 64
default:
res = t
}
var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0)
if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
}
return res
}
func (db *mysql) SupportInsertMany() bool {
return true
}
func (db *mysql) QuoteStr() string {
return "`"
}
func (db *mysql) SupportEngine() bool {
return true
}
func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
func (db *mysql) SupportCharset() bool {
return true
}
func (db *mysql) IndexOnTable() bool {
return true
}
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args
}
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args
}
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.DbName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args
}
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.DbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, nil, err
}
defer cnn.Close()
rows, err := cnn.Query(s, args...)
if err != nil {
return nil, nil, err
}
cols := make(map[string]*Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(Column)
col.Indexes = make(map[string]bool)
var columnName, isNullable, colType, colKey, extra string
var colDefault *string
err = rows.Scan(&columnName, &isNullable, &colDefault, &colType, &colKey, &extra)
if err != nil {
return nil, nil, err
}
col.Name = strings.Trim(columnName, "` ")
if "YES" == isNullable {
col.Nullable = true
}
if colDefault != nil {
col.Default = *colDefault
}
cts := strings.Split(colType, "(")
var len1, len2 int
if len(cts) == 2 {
idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil {
return nil, nil, err
}
if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1])
if err != nil {
return nil, nil, err
}
}
}
colName := cts[0]
colType = strings.ToUpper(colName)
col.Length = len1
col.Length2 = len2
if _, ok := SqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2}
} else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
}
if colKey == "PRI" {
col.IsPrimaryKey = true
}
if colKey == "UNI" {
//col.is
}
if extra == "auto_increment" {
col.IsAutoIncrement = true
}
if col.SQLType.IsText() {
if col.Default != "" {
col.Default = "'" + col.Default + "'"
}
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
return colSeq, cols, nil
}
func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.DbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
defer cnn.Close()
rows, err := cnn.Query(s, args...)
if err != nil {
return nil, err
}
tables := make([]*Table, 0)
for rows.Next() {
table := NewEmptyTable()
var name, engine, tableRows string
var autoIncr *string
err = rows.Scan(&name, &engine, &tableRows, &autoIncr)
if err != nil {
return nil, err
}
table.Name = name
tables = append(tables, table)
}
return tables, nil
}
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.DbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
defer cnn.Close()
rows, err := cnn.Query(s, args...)
if err != nil {
return nil, err
}
indexes := make(map[string]*Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, nonUnique string
err = rows.Scan(&indexName, &nonUnique, &colName)
if err != nil {
return nil, err
}
if indexName == "PRIMARY" {
continue
}
if "YES" == nonUnique || nonUnique == "1" {
indexType = IndexType
} else {
indexType = UniqueType
}
colName = strings.Trim(colName, "` ")
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)]
}
var index *Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(Index)
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
}
index.AddColumn(colName)
}
return indexes, nil
}
func (db *mysql) Filters() []Filter {
return []Filter{&IdFilter{}}
}

View File

@ -1,258 +1,250 @@
package xorm package dialects
import ( import (
"database/sql" "errors"
"errors" "fmt"
"fmt" "strconv"
"regexp" "strings"
"strconv"
"strings" . "github.com/lunny/xorm/core"
) )
type oracle struct { func init() {
base RegisterDialect("oracle", &oracle{})
} }
type oracleParser struct { type oracle struct {
} Base
}
//dataSourceName=user/password@ipv4:port/dbname
//dataSourceName=user/password@[ipv6]:port/dbname func (db *oracle) Init(uri *Uri, drivername, dataSourceName string) error {
func (p *oracleParser) parse(driverName, dataSourceName string) (*uri, error) { return db.Base.Init(db, uri, drivername, dataSourceName)
db := &uri{dbType: ORACLE_OCI} }
dsnPattern := regexp.MustCompile(
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@ func (db *oracle) SqlType(c *Column) string {
`(?P<net>.*)` + // ip:port var res string
`\/(?P<dbname>.*)`) // dbname switch t := c.SQLType.Name; t {
matches := dsnPattern.FindStringSubmatch(dataSourceName) case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial:
names := dsnPattern.SubexpNames() return "NUMBER"
for i, match := range matches { case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea:
switch names[i] { return Blob
case "dbname": case Time, DateTime, TimeStamp:
db.dbName = match res = TimeStamp
} case TimeStampz:
} res = "TIMESTAMP WITH TIME ZONE"
if db.dbName == "" { case Float, Double, Numeric, Decimal:
return nil, errors.New("dbname is empty") res = "NUMBER"
} case Text, MediumText, LongText:
return db, nil res = "CLOB"
} case Char, Varchar, TinyText:
return "VARCHAR2"
func (db *oracle) Init(drivername, uri string) error { default:
return db.base.init(&oracleParser{}, drivername, uri) res = t
} }
func (db *oracle) SqlType(c *Column) string { var hasLen1 bool = (c.Length > 0)
var res string var hasLen2 bool = (c.Length2 > 0)
switch t := c.SQLType.Name; t { if hasLen1 {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial: res += "(" + strconv.Itoa(c.Length) + ")"
return "NUMBER" } else if hasLen2 {
case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea: res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
return Blob }
case Time, DateTime, TimeStamp: return res
res = TimeStamp }
case TimeStampz:
res = "TIMESTAMP WITH TIME ZONE" func (db *oracle) SupportInsertMany() bool {
case Float, Double, Numeric, Decimal: return true
res = "NUMBER" }
case Text, MediumText, LongText:
res = "CLOB" func (db *oracle) QuoteStr() string {
case Char, Varchar, TinyText: return "\""
return "VARCHAR2" }
default:
res = t func (db *oracle) AutoIncrStr() string {
} return ""
}
var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) func (db *oracle) SupportEngine() bool {
if hasLen1 { return false
res += "(" + strconv.Itoa(c.Length) + ")" }
} else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" func (db *oracle) SupportCharset() bool {
} return false
return res }
}
func (db *oracle) IndexOnTable() bool {
func (db *oracle) SupportInsertMany() bool { return false
return true }
}
func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
func (db *oracle) QuoteStr() string { args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)}
return "\"" return `SELECT INDEX_NAME FROM USER_INDEXES ` +
} `WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args
}
func (db *oracle) AutoIncrStr() string {
return "" func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) {
} args := []interface{}{strings.ToUpper(tableName)}
return `SELECT table_name FROM user_tables WHERE table_name = ?`, args
func (db *oracle) SupportEngine() bool { }
return false
} func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)}
func (db *oracle) SupportCharset() bool { return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" +
return false " AND column_name = ?", args
} }
func (db *oracle) IndexOnTable() bool { func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) {
return false args := []interface{}{strings.ToUpper(tableName)}
} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
func (db *oracle) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(idxName)} cnn, err := Open(db.DriverName(), db.DataSourceName())
return `SELECT INDEX_NAME FROM USER_INDEXES ` + if err != nil {
`WHERE TABLE_NAME = ? AND INDEX_NAME = ?`, args return nil, nil, err
} }
defer cnn.Close()
func (db *oracle) TableCheckSql(tableName string) (string, []interface{}) { rows, err := cnn.Query(s, args...)
args := []interface{}{strings.ToUpper(tableName)} if err != nil {
return `SELECT table_name FROM user_tables WHERE table_name = ?`, args return nil, nil, err
} }
defer rows.Close()
func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{strings.ToUpper(tableName), strings.ToUpper(colName)} cols := make(map[string]*Column)
return "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = ?" + colSeq := make([]string, 0)
" AND column_name = ?", args for rows.Next() {
} col := new(Column)
col.Indexes = make(map[string]bool)
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{strings.ToUpper(tableName)} var colName, colDefault, nullable, dataType, dataPrecision, dataScale string
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + var dataLen int
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
cnn, err := sql.Open(db.driverName, db.dataSourceName) &dataScale, &nullable)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close()
res, err := query(cnn, s, args...) col.Name = strings.Trim(colName, `" `)
if err != nil { col.Default = colDefault
return nil, nil, err
} if nullable == "Y" {
cols := make(map[string]*Column) col.Nullable = true
colSeq := make([]string, 0) } else {
for _, record := range res { col.Nullable = false
col := new(Column) }
col.Indexes = make(map[string]bool)
for name, content := range record { switch dataType {
switch name { case "VARCHAR2":
case "column_name": col.SQLType = SQLType{Varchar, 0, 0}
col.Name = strings.Trim(string(content), `" `) case "TIMESTAMP WITH TIME ZONE":
case "data_default": col.SQLType = SQLType{TimeStampz, 0, 0}
col.Default = string(content) default:
case "nullable": col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0}
if string(content) == "Y" { }
col.Nullable = true if _, ok := SqlTypes[col.SQLType.Name]; !ok {
} else { return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
col.Nullable = false }
}
case "data_type": col.Length = dataLen
ct := string(content)
switch ct { if col.SQLType.IsText() {
case "VARCHAR2": if col.Default != "" {
col.SQLType = SQLType{Varchar, 0, 0} col.Default = "'" + col.Default + "'"
case "TIMESTAMP WITH TIME ZONE": }
col.SQLType = SQLType{TimeStamp, 0, 0} }
default: cols[col.Name] = col
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} colSeq = append(colSeq, col.Name)
} }
if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) return colSeq, cols, nil
} }
case "data_length":
i, err := strconv.Atoi(string(content)) func (db *oracle) GetTables() ([]*Table, error) {
if err != nil { args := []interface{}{}
return nil, nil, errors.New("retrieve length error") s := "SELECT table_name FROM user_tables"
} cnn, err := Open(db.DriverName(), db.DataSourceName())
col.Length = i if err != nil {
case "data_precision": return nil, err
case "data_scale": }
} defer cnn.Close()
} rows, err := cnn.Query(s, args...)
if col.SQLType.IsText() { if err != nil {
if col.Default != "" { return nil, err
col.Default = "'" + col.Default + "'" }
}
} tables := make([]*Table, 0)
cols[col.Name] = col for rows.Next() {
colSeq = append(colSeq, col.Name) table := NewEmptyTable()
} err = rows.Scan(&table.Name)
if err != nil {
return colSeq, cols, nil return nil, err
} }
func (db *oracle) GetTables() ([]*Table, error) { tables = append(tables, table)
args := []interface{}{} }
s := "SELECT table_name FROM user_tables" return tables, nil
cnn, err := sql.Open(db.driverName, db.dataSourceName) }
if err != nil {
return nil, err func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
} args := []interface{}{tableName}
defer cnn.Close() s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
res, err := query(cnn, s, args...) "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
if err != nil {
return nil, err cnn, err := Open(db.DriverName(), db.DataSourceName())
} if err != nil {
return nil, err
tables := make([]*Table, 0) }
for _, record := range res { defer cnn.Close()
table := new(Table) rows, err := cnn.Query(s, args...)
for name, content := range record { if err != nil {
switch name { return nil, err
case "table_name": }
table.Name = string(content) defer rows.Close()
}
} indexes := make(map[string]*Index, 0)
tables = append(tables, table) for rows.Next() {
} var indexType int
return tables, nil var indexName, colName, uniqueness string
}
err = rows.Scan(&colName, &uniqueness, &indexName)
func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { if err != nil {
args := []interface{}{tableName} return nil, err
s := "SELECT t.column_name,i.table_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + }
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
indexName = strings.Trim(indexName, `" `)
cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if uniqueness == "UNIQUE" {
return nil, err indexType = UniqueType
} } else {
defer cnn.Close() indexType = IndexType
res, err := query(cnn, s, args...) }
if err != nil {
return nil, err var index *Index
} var ok bool
if index, ok = indexes[indexName]; !ok {
indexes := make(map[string]*Index, 0) index = new(Index)
for _, record := range res { index.Type = indexType
var indexType int index.Name = indexName
var indexName string indexes[indexName] = index
var colName string }
index.AddColumn(colName)
for name, content := range record { }
switch name { return indexes, nil
case "index_name": }
indexName = strings.Trim(string(content), `" `)
case "uniqueness": // PgSeqFilter filter SQL replace ?, ? ... to :1, :2 ...
c := string(content) type OracleSeqFilter struct {
if c == "UNIQUE" { }
indexType = UniqueType
} else { func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
indexType = IndexType counts := strings.Count(sql, "?")
} for i := 1; i <= counts; i++ {
case "column_name": newstr := ":" + fmt.Sprintf("%v", i)
colName = string(content) sql = strings.Replace(sql, "?", newstr, 1)
} }
} return sql
}
var index *Index
var ok bool func (db *oracle) Filters() []Filter {
if index, ok = indexes[indexName]; !ok { return []Filter{&QuoteFilter{}, &OracleSeqFilter{}, &IdFilter{}}
index = new(Index) }
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
}
index.AddColumn(colName)
}
return indexes, nil
}

View File

@ -1,65 +1,24 @@
package xorm package dialects
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
. "github.com/lunny/xorm/core"
) )
func init() {
RegisterDialect("postgres", &postgres{})
}
type postgres struct { type postgres struct {
base Base
} }
type values map[string]string func (db *postgres) Init(uri *Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
func (vs values) Set(k, v string) {
vs[k] = v
}
func (vs values) Get(k string) (v string) {
return vs[k]
}
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseOpts(name string, o values) {
if len(name) == 0 {
return
}
name = strings.TrimSpace(name)
ps := strings.Split(name, " ")
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
}
type postgresParser struct {
}
func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) {
db := &uri{dbType: POSTGRES}
o := make(values)
parseOpts(dataSourceName, o)
db.dbName = o.Get("dbname")
if db.dbName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
}
func (db *postgres) Init(drivername, uri string) error {
return db.base.init(&postgresParser{}, drivername, uri)
} }
func (db *postgres) SqlType(c *Column) string { func (db *postgres) SqlType(c *Column) string {
@ -153,68 +112,74 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res {
for rows.Next() {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { var colName, isNullable, dataType string
switch name { var maxLenStr, colDefault, numPrecision, numRadix *string
case "column_name": err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &numPrecision, &numRadix)
col.Name = strings.Trim(string(content), `" `) if err != nil {
case "column_default": return nil, nil, err
if strings.HasPrefix(string(content), "nextval") { }
col.IsPrimaryKey = true
} else { var maxLen int
col.Default = string(content) if maxLenStr != nil {
} maxLen, err = strconv.Atoi(*maxLenStr)
case "is_nullable": if err != nil {
if string(content) == "YES" { return nil, nil, err
col.Nullable = true
} else {
col.Nullable = false
}
case "data_type":
ct := string(content)
switch ct {
case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0}
case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0}
case "double precision":
col.SQLType = SQLType{Double, 0, 0}
case "boolean":
col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone":
col.SQLType = SQLType{Time, 0, 0}
default:
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
}
if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
}
case "character_maximum_length":
i, err := strconv.Atoi(string(content))
if err != nil {
return nil, nil, errors.New("retrieve length error")
}
col.Length = i
case "numeric_precision":
case "numeric_precision_radix":
} }
} }
col.Name = strings.Trim(colName, `" `)
if colDefault != nil {
if strings.HasPrefix(*colDefault, "nextval") {
col.IsPrimaryKey = true
} else {
col.Default = *colDefault
}
}
if isNullable == "YES" {
col.Nullable = true
} else {
col.Nullable = false
}
switch dataType {
case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0}
case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0}
case "double precision":
col.SQLType = SQLType{Double, 0, 0}
case "boolean":
col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone":
col.SQLType = SQLType{Time, 0, 0}
default:
col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0}
}
if _, ok := SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
}
col.Length = maxLen
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
@ -230,25 +195,25 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'" s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for rows.Next() {
table := new(Table) table := NewEmptyTable()
for name, content := range record { var name string
switch name { err = rows.Scan(&name)
case "tablename": if err != nil {
table.Name = string(content) return nil, err
}
} }
table.Name = name
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
@ -256,39 +221,37 @@ func (db *postgres) GetTables() ([]*Table, error) {
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" s := "SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for rows.Next() {
var indexType int var indexType int
var indexName string var indexName, indexdef string
var colNames []string var colNames []string
err = rows.Scan(&indexName, &indexdef)
for name, content := range record { if err != nil {
switch name { return nil, err
case "indexname":
indexName = strings.Trim(string(content), `" `)
case "indexdef":
c := string(content)
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
indexType = UniqueType
} else {
indexType = IndexType
}
cs := strings.Split(c, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
}
} }
indexName = strings.Trim(indexName, `" `)
if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") {
indexType = UniqueType
} else {
indexType = IndexType
}
cs := strings.Split(indexdef, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
@ -307,3 +270,24 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
} }
return indexes, nil return indexes, nil
} }
// PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type PgSeqFilter struct {
}
func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
segs := strings.Split(sql, "?")
size := len(segs)
res := ""
for i, c := range segs {
if i < size-1 {
res += c + fmt.Sprintf("$%v", i+1)
}
}
res += segs[size-1]
return res
}
func (db *postgres) Filters() []Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}, &PgSeqFilter{}}
}

View File

@ -1,23 +1,21 @@
package xorm package dialects
import ( import (
"database/sql"
"strings" "strings"
. "github.com/lunny/xorm/core"
) )
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
type sqlite3 struct { type sqlite3 struct {
base Base
} }
type sqlite3Parser struct { func (db *sqlite3) Init(uri *Uri, drivername, dataSourceName string) error {
} return db.Base.Init(db, uri, drivername, dataSourceName)
func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) {
return &uri{dbType: SQLITE, dbName: dataSourceName}, nil
}
func (db *sqlite3) Init(drivername, dataSourceName string) error {
return db.base.init(&sqlite3Parser{}, drivername, dataSourceName)
} }
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *Column) string {
@ -89,28 +87,29 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...)
rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer rows.Close()
var sql string var name string
for _, record := range res { for rows.Next() {
for name, content := range record { err = rows.Scan(&name)
if name == "sql" { if err != nil {
sql = string(content) return nil, nil, err
}
} }
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(name, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(name, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",") colCreates := strings.Split(name[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
@ -148,24 +147,23 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for rows.Next() {
table := new(Table) table := NewEmptyTable()
for name, content := range record { err = rows.Scan(&table.Name)
switch name { if err != nil {
case "name": return nil, err
table.Name = string(content)
}
} }
if table.Name == "sqlite_sequence" { if table.Name == "sqlite_sequence" {
continue continue
@ -178,25 +176,30 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName) cnn, err := Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) rows, err := cnn.Query(s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for rows.Next() {
index := new(Index) var sql string
sql := string(record["sql"]) err = rows.Scan(&sql)
if err != nil {
return nil, err
}
if sql == "" { if sql == "" {
continue continue
} }
index := new(Index)
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
if nNStart == -1 || nNEnd == -1 { if nNStart == -1 || nNEnd == -1 {
@ -230,3 +233,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
return indexes, nil return indexes, nil
} }
func (db *sqlite3) Filters() []Filter {
return []Filter{&IdFilter{}}
}

View File

@ -1,65 +1,65 @@
When a struct auto mapping to a database's table, the below table describes how they change to each other: When a struct auto mapping to a database's table, the below table describes how they change to each other:
<table> <table>
<tr> <tr>
<td>go type's kind <td>go type's kind
</td> </td>
<td>value method</td> <td>value method</td>
<td>xorm type <td>xorm type
</td> </td>
</tr> </tr>
<tr> <tr>
<td>implemented Conversion</td> <td>implemented Conversion</td>
<td>Conversion.ToDB / Conversion.FromDB</td> <td>Conversion.ToDB / Conversion.FromDB</td>
<td>Text</td> <td>Text</td>
</tr> </tr>
<tr> <tr>
<td>int, int8, int16, int32, uint, uint8, uint16, uint32</td> <td>int, int8, int16, int32, uint, uint8, uint16, uint32</td>
<td></td> <td></td>
<td> Int </td> <td> Int </td>
</tr> </tr>
<tr> <tr>
<td>int64, uint64</td><td></td><td>BigInt</td> <td>int64, uint64</td><td></td><td>BigInt</td>
</tr> </tr>
<tr><td>float32</td><td></td><td>Float</td> <tr><td>float32</td><td></td><td>Float</td>
</tr> </tr>
<tr><td>float64</td><td></td><td>Double</td> <tr><td>float64</td><td></td><td>Double</td>
</tr> </tr>
<tr><td>complex64, complex128</td> <tr><td>complex64, complex128</td>
<td>json.Marshal / json.UnMarshal</td> <td>json.Marshal / json.UnMarshal</td>
<td>Varchar(64)</td> <td>Varchar(64)</td>
</tr> </tr>
<tr> <tr>
<td>[]uint8</td><td></td><td>Blob</td> <td>[]uint8</td><td></td><td>Blob</td>
</tr> </tr>
<tr> <tr>
<td>array, slice, map except []uint8</td> <td>array, slice, map except []uint8</td>
<td>json.Marshal / json.UnMarshal</td> <td>json.Marshal / json.UnMarshal</td>
<td>Text</td> <td>Text</td>
</tr> </tr>
<tr> <tr>
<td>bool</td><td>1 or 0</td><td>Bool</td> <td>bool</td><td>1 or 0</td><td>Bool</td>
</tr> </tr>
<tr> <tr>
<td>string</td><td></td><td>Varchar(255)</td> <td>string</td><td></td><td>Varchar(255)</td>
</tr> </tr>
<tr> <tr>
<td>time.Time</td><td></td><td>DateTime</td> <td>time.Time</td><td></td><td>DateTime</td>
</tr> </tr>
<tr> <tr>
<td>cascade struct</td><td>primary key field value</td><td>BigInt</td> <td>cascade struct</td><td>primary key field value</td><td>BigInt</td>
</tr> </tr>
<tr> <tr>
<tr> <tr>
<td>struct</td><td>json.Marshal / json.UnMarshal</td><td>Text</td> <td>struct</td><td>json.Marshal / json.UnMarshal</td><td>Text</td>
</tr> </tr>
<tr> <tr>
<td> <td>
Others Others
</td> </td>
<td></td> <td></td>
<td> <td>
Text Text
</td> </td>
</tr> </tr>
</table> </table>

View File

@ -1,438 +1,438 @@
<table> <table>
<tr> <tr>
<td>xorm <td>xorm
</td> </td>
<td>mysql <td>mysql
</td> </td>
<td>sqlite3 <td>sqlite3
</td> </td>
<td>postgres <td>postgres
</td> </td>
<td>remark</td> <td>remark</td>
</tr> </tr>
<tr> <tr>
<td>BIT <td>BIT
</td> </td>
<td>BIT <td>BIT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>BIT <td>BIT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>TINYINT <td>TINYINT
</td> </td>
<td>TINYINT <td>TINYINT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>SMALLINT <td>SMALLINT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>SMALLINT <td>SMALLINT
</td> </td>
<td>SMALLINT <td>SMALLINT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>SMALLINT <td>SMALLINT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>MEDIUMINT <td>MEDIUMINT
</td> </td>
<td>MEDIUMINT <td>MEDIUMINT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>INT <td>INT
</td> </td>
<td>INT <td>INT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>INTEGER <td>INTEGER
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>BIGINT <td>BIGINT
</td> </td>
<td>BIGINT <td>BIGINT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>BIGINT <td>BIGINT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>CHAR <td>CHAR
</td> </td>
<td>CHAR <td>CHAR
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>CHAR <td>CHAR
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>VARCHAR <td>VARCHAR
</td> </td>
<td>VARCHAR <td>VARCHAR
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>VARCHAR <td>VARCHAR
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>TINYTEXT <td>TINYTEXT
</td> </td>
<td>TINYTEXT <td>TINYTEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>MEDIUMTEXT <td>MEDIUMTEXT
</td> </td>
<td>MEDIUMTEXT <td>MEDIUMTEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>LONGTEXT <td>LONGTEXT
</td> </td>
<td>LONGTEXT <td>LONGTEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>BINARY <td>BINARY
</td> </td>
<td>BINARY <td>BINARY
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>VARBINARY <td>VARBINARY
</td> </td>
<td>VARBINARY <td>VARBINARY
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>DATE <td>DATE
</td> </td>
<td>DATE <td>DATE
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>DATE <td>DATE
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>DATETIME <td>DATETIME
</td> </td>
<td>DATETIME <td>DATETIME
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>TIMESTAMP <td>TIMESTAMP
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>TIME <td>TIME
</td> </td>
<td>TIME <td>TIME
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>TIME <td>TIME
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>TIMESTAMP <td>TIMESTAMP
</td> </td>
<td>TIMESTAMP <td>TIMESTAMP
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>TIMESTAMP <td>TIMESTAMP
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>TIMESTAMPZ <td>TIMESTAMPZ
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TEXT <td>TEXT
</td> </td>
<td>TIMESTAMP with zone <td>TIMESTAMP with zone
</td> </td>
<td>timestamp with zone info</td> <td>timestamp with zone info</td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>REAL <td>REAL
</td> </td>
<td>REAL <td>REAL
</td> </td>
<td>REAL <td>REAL
</td> </td>
<td>REAL <td>REAL
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>FLOAT <td>FLOAT
</td> </td>
<td>FLOAT <td>FLOAT
</td> </td>
<td>REAL <td>REAL
</td> </td>
<td>REAL <td>REAL
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>DOUBLE <td>DOUBLE
</td> </td>
<td>DOUBLE <td>DOUBLE
</td> </td>
<td>REAL <td>REAL
</td> </td>
<td>DOUBLE PRECISION <td>DOUBLE PRECISION
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>DECIMAL <td>DECIMAL
</td> </td>
<td>DECIMAL <td>DECIMAL
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>DECIMAL <td>DECIMAL
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td>NUMERIC <td>NUMERIC
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>TINYBLOB <td>TINYBLOB
</td> </td>
<td>TINYBLOB <td>TINYBLOB
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>BLOB <td>BLOB
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>MEDIUMBLOB <td>MEDIUMBLOB
</td> </td>
<td>MEDIUMBLOB <td>MEDIUMBLOB
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>LONGBLOB <td>LONGBLOB
</td> </td>
<td>LONGBLOB <td>LONGBLOB
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>BYTEA <td>BYTEA
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BLOB <td>BLOB
</td> </td>
<td>BYTEA <td>BYTEA
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr><td cols="5"></td></tr> <tr><td cols="5"></td></tr>
<tr> <tr>
<td>BOOL <td>BOOL
</td> </td>
<td>TINYINT <td>TINYINT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>BOOLEAN <td>BOOLEAN
</td> </td>
<td></td> <td></td>
</tr> </tr>
<tr> <tr>
<td>SERIAL <td>SERIAL
</td> </td>
<td>INT <td>INT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>SERIAL <td>SERIAL
</td> </td>
<td>auto increment</td> <td>auto increment</td>
</tr> </tr>
<tr> <tr>
<td>BIGSERIAL <td>BIGSERIAL
</td> </td>
<td>BIGINT <td>BIGINT
</td> </td>
<td>INTEGER <td>INTEGER
</td> </td>
<td>BIGSERIAL <td>BIGSERIAL
</td> </td>
<td>auto increment</td> <td>auto increment</td>
</tr> </tr>
</table> </table>

View File

@ -1,31 +1,31 @@
## 更新日志 ## 更新日志
* **v0.3.1** * **v0.3.1**
新特性: 新特性:
* 支持 MSSQL DB 通过 ODBC 驱动 ([github.com/lunny/godbc](https://github.com/lunny/godbc)); * 支持 MSSQL DB 通过 ODBC 驱动 ([github.com/lunny/godbc](https://github.com/lunny/godbc));
* 通过多个pk标记支持联合主键; * 通过多个pk标记支持联合主键;
* 新增 Rows() API 用来遍历查询结果该函数提供了类似sql.Rows的相似用法可作为 Iterate() API 的可选替代; * 新增 Rows() API 用来遍历查询结果该函数提供了类似sql.Rows的相似用法可作为 Iterate() API 的可选替代;
* ORM 结构体现在允许内建类型的指针作为成员使得数据库为null成为可能 * ORM 结构体现在允许内建类型的指针作为成员使得数据库为null成为可能
* Before 和 After 支持 * Before 和 After 支持
改进: 改进:
* 允许 int/int32/int64/uint/uint32/uint64/string 作为主键类型 * 允许 int/int32/int64/uint/uint32/uint64/string 作为主键类型
* 查询函数 Get()/Find()/Iterate() 在性能上的改进 * 查询函数 Get()/Find()/Iterate() 在性能上的改进
* **v0.2.3** : 改善了文档提供了乐观锁支持添加了带时区时间字段支持Mapper现在分成表名Mapper和字段名Mapper同时实现了表或字段的自定义前缀后缀Insert方法的返回值含义从id, err更改为 affected, err请大家注意添加了UseBool 和 Distinct函数。 * **v0.2.3** : 改善了文档提供了乐观锁支持添加了带时区时间字段支持Mapper现在分成表名Mapper和字段名Mapper同时实现了表或字段的自定义前缀后缀Insert方法的返回值含义从id, err更改为 affected, err请大家注意添加了UseBool 和 Distinct函数。
* **v0.2.2** : Postgres驱动新增了对lib/pq的支持新增了逐条遍历方法Iterate新增了SetMaxConns(go1.2+)支持修复了bug若干 * **v0.2.2** : Postgres驱动新增了对lib/pq的支持新增了逐条遍历方法Iterate新增了SetMaxConns(go1.2+)支持修复了bug若干
* **v0.2.1** : 新增数据库反转工具当前支持go和c++代码的生成,详见 [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md); 修复了一些bug. * **v0.2.1** : 新增数据库反转工具当前支持go和c++代码的生成,详见 [Xorm Tool README](https://github.com/lunny/xorm/blob/master/xorm/README.md); 修复了一些bug.
* **v0.2.0** : 新增 [缓存](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#120)支持查询速度提升3-5倍 新增数据库表和Struct同名的映射方式 新增Sync同步表结构 * **v0.2.0** : 新增 [缓存](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#120)支持查询速度提升3-5倍 新增数据库表和Struct同名的映射方式 新增Sync同步表结构
* **v0.1.9** : 新增 postgres 和 mymysql 驱动支持; 在Postgres中支持原始SQL语句中使用 ` 和 ? 符号; 新增Cols, StoreEngine, Charset 函数SQL语句打印支持io.Writer接口默认打印到控制台新增更多的字段类型支持详见 [映射规则](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#21)删除废弃的MakeSession和Create函数。 * **v0.1.9** : 新增 postgres 和 mymysql 驱动支持; 在Postgres中支持原始SQL语句中使用 ` 和 ? 符号; 新增Cols, StoreEngine, Charset 函数SQL语句打印支持io.Writer接口默认打印到控制台新增更多的字段类型支持详见 [映射规则](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#21)删除废弃的MakeSession和Create函数。
* **v0.1.8** : 新增联合index联合unique支持请查看 [映射规则](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#21)。 * **v0.1.8** : 新增联合index联合unique支持请查看 [映射规则](https://github.com/lunny/xorm/blob/master/docs/QuickStart.md#21)。
* **v0.1.7** : 新增IConnectPool接口以及NoneConnectPool, SysConnectPool, SimpleConnectPool三种实现可以选择不使用连接池使用系统连接池和使用自带连接池三种实现默认为SysConnectPool即系统自带的连接池。同时支持自定义连接池。Engine新增Close方法在系统退出时应调用此方法。 * **v0.1.7** : 新增IConnectPool接口以及NoneConnectPool, SysConnectPool, SimpleConnectPool三种实现可以选择不使用连接池使用系统连接池和使用自带连接池三种实现默认为SysConnectPool即系统自带的连接池。同时支持自定义连接池。Engine新增Close方法在系统退出时应调用此方法。
* **v0.1.6** : 新增Conversion支持自定义类型到数据库类型的转换新增查询结构体自动检测匿名成员支持新增单向映射支持 * **v0.1.6** : 新增Conversion支持自定义类型到数据库类型的转换新增查询结构体自动检测匿名成员支持新增单向映射支持
* **v0.1.5** : 新增对多线程的支持新增Sql()函数支持任意sql语句的struct查询Get函数返回值变动MakeSession和Create函数被NewSession和NewEngine函数替代 * **v0.1.5** : 新增对多线程的支持新增Sql()函数支持任意sql语句的struct查询Get函数返回值变动MakeSession和Create函数被NewSession和NewEngine函数替代
* **v0.1.4** : Get函数和Find函数新增简单的级联载入功能对更多的数据库类型支持。 * **v0.1.4** : Get函数和Find函数新增简单的级联载入功能对更多的数据库类型支持。
* **v0.1.3** : Find函数现在支持传入Slice或者Map当传入Map时key为id新增Table函数以为多表和临时表进行支持。 * **v0.1.3** : Find函数现在支持传入Slice或者Map当传入Map时key为id新增Table函数以为多表和临时表进行支持。
* **v0.1.2** : Insert函数支持混合struct和slice指针传入并根据数据库类型自动批量插入同时自动添加事务 * **v0.1.2** : Insert函数支持混合struct和slice指针传入并根据数据库类型自动批量插入同时自动添加事务
* **v0.1.1** : 添加 Id, In 函数,改善 README 文档 * **v0.1.1** : 添加 Id, In 函数,改善 README 文档
* **v0.1.0** : 初始化工程 * **v0.1.0** : 初始化工程

File diff suppressed because it is too large Load Diff

38
drivers/goracle.go Normal file
View File

@ -0,0 +1,38 @@
package drivers
import (
"errors"
"regexp"
"github.com/lunny/xorm/core"
)
func init() {
core.RegisterDriver("goracle", &goracleDriver{})
}
type goracleDriver struct {
}
func (cfg *goracleDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
for i, match := range matches {
switch names[i] {
case "dbname":
db.DbName = match
}
}
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
}

View File

@ -1,20 +1,22 @@
package xorm package drivers
import ( import (
"errors" "errors"
"strings" "strings"
"time" "time"
"github.com/lunny/xorm/core"
) )
type mymysql struct { func init() {
mysql core.RegisterDriver("mymysql", &mymysqlDriver{})
} }
type mymysqlParser struct { type mymysqlDriver struct {
} }
func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) { func (p *mymysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &uri{dbType: MYSQL} db := &core.Uri{DbType: core.MYSQL}
pd := strings.SplitN(dataSourceName, "*", 2) pd := strings.SplitN(dataSourceName, "*", 2)
if len(pd) == 2 { if len(pd) == 2 {
@ -23,9 +25,9 @@ func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
if len(p) != 2 { if len(p) != 2 {
return nil, errors.New("Wrong protocol part of URI") return nil, errors.New("Wrong protocol part of URI")
} }
db.proto = p[0] db.Proto = p[0]
options := strings.Split(p[1], ",") options := strings.Split(p[1], ",")
db.raddr = options[0] db.Raddr = options[0]
for _, o := range options[1:] { for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2) kv := strings.SplitN(o, "=", 2)
var k, v string var k, v string
@ -36,13 +38,13 @@ func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
} }
switch k { switch k {
case "laddr": case "laddr":
db.laddr = v db.Laddr = v
case "timeout": case "timeout":
to, err := time.ParseDuration(v) to, err := time.ParseDuration(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
db.timeout = to db.Timeout = to
default: default:
return nil, errors.New("Unknown option: " + k) return nil, errors.New("Unknown option: " + k)
} }
@ -55,13 +57,9 @@ func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
if len(dup) != 3 { if len(dup) != 3 {
return nil, errors.New("Wrong database part of URI") return nil, errors.New("Wrong database part of URI")
} }
db.dbName = dup[0] db.DbName = dup[0]
db.user = dup[1] db.User = dup[1]
db.passwd = dup[2] db.Passwd = dup[2]
return db, nil return db, nil
} }
func (db *mymysql) Init(drivername, uri string) error {
return db.mysql.base.init(&mymysqlParser{}, drivername, uri)
}

50
drivers/mysql.go Normal file
View File

@ -0,0 +1,50 @@
package drivers
import (
"regexp"
"strings"
"github.com/lunny/xorm/core"
)
func init() {
core.RegisterDriver("mysql", &mysqlDriver{})
}
type mysqlDriver struct {
}
func (p *mysqlDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &core.Uri{DbType: core.MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.DbName = match
case "params":
if len(match) > 0 {
kvs := strings.Split(match, "&")
for _, kv := range kvs {
splits := strings.Split(kv, "=")
if len(splits) == 2 {
switch splits[0] {
case "charset":
uri.Charset = splits[1]
}
}
}
}
}
}
return uri, nil
}

37
drivers/oci.go Normal file
View File

@ -0,0 +1,37 @@
package drivers
import (
"errors"
"regexp"
"github.com/lunny/xorm/core"
)
func init() {
core.RegisterDriver("oci", &ociDriver{})
}
type ociDriver struct {
}
//dataSourceName=user/password@ipv4:port/dbname
//dataSourceName=user/password@[ipv6]:port/dbname
func (p *ociDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.ORACLE}
dsnPattern := regexp.MustCompile(
`^(?P<user>.*)\/(?P<password>.*)@` + // user:password@
`(?P<net>.*)` + // ip:port
`\/(?P<dbname>.*)`) // dbname
matches := dsnPattern.FindStringSubmatch(dataSourceName)
names := dsnPattern.SubexpNames()
for i, match := range matches {
switch names[i] {
case "dbname":
db.DbName = match
}
}
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
}

34
drivers/odbc.go Normal file
View File

@ -0,0 +1,34 @@
package drivers
import (
"errors"
"strings"
"github.com/lunny/xorm/core"
)
func init() {
core.RegisterDriver("odbc", &odbcDriver{})
}
type odbcDriver struct {
}
func (p *odbcDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
kv := strings.Split(dataSourceName, ";")
var dbName string
for _, c := range kv {
vv := strings.Split(strings.TrimSpace(c), "=")
if len(vv) == 2 {
switch strings.ToLower(vv[0]) {
case "database":
dbName = vv[1]
}
}
}
if dbName == "" {
return nil, errors.New("no db name provided")
}
return &core.Uri{DbName: dbName, DbType: core.MSSQL}, nil
}

59
drivers/pq.go Normal file
View File

@ -0,0 +1,59 @@
package drivers
import (
"errors"
"fmt"
"strings"
"github.com/lunny/xorm/core"
)
func init() {
core.RegisterDriver("postgres", &pqDriver{})
}
type pqDriver struct {
}
type values map[string]string
func (vs values) Set(k, v string) {
vs[k] = v
}
func (vs values) Get(k string) (v string) {
return vs[k]
}
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseOpts(name string, o values) {
if len(name) == 0 {
return
}
name = strings.TrimSpace(name)
ps := strings.Split(name, " ")
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname")
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
}

16
drivers/sqlite3.go Normal file
View File

@ -0,0 +1,16 @@
package drivers
import (
"github.com/lunny/xorm/core"
)
func init() {
core.RegisterDriver("sqlite3", &sqlite3Driver{})
}
type sqlite3Driver struct {
}
func (p *sqlite3Driver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
return &core.Uri{DbType: core.SQLITE, DbName: dataSourceName}, nil
}

192
engine.go
View File

@ -12,40 +12,10 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"github.com/lunny/xorm/core"
) )
const (
POSTGRES = "postgres"
SQLITE = "sqlite3"
MYSQL = "mysql"
MYMYSQL = "mymysql"
MSSQL = "mssql"
ORACLE_OCI = "oci8"
)
// a dialect is a driver's wrapper
type dialect interface {
Init(DriverName, DataSourceName string) error
URI() *uri
DBType() string
SqlType(t *Column) string
SupportInsertMany() bool
QuoteStr() string
AutoIncrStr() string
SupportEngine() bool
SupportCharset() bool
IndexOnTable() bool
IndexCheckSql(tableName, idxName string) (string, []interface{})
TableCheckSql(tableName string) (string, []interface{})
ColumnCheckSql(tableName, colName string) (string, []interface{})
GetColumns(tableName string) ([]string, map[string]*Column, error)
GetTables() ([]*Table, error)
GetIndexes(tableName string) (map[string]*Index, error)
}
type PK []interface{} type PK []interface{}
// Engine is the major struct of xorm, it means a database manager. // Engine is the major struct of xorm, it means a database manager.
@ -56,18 +26,19 @@ type Engine struct {
TagIdentifier string TagIdentifier string
DriverName string DriverName string
DataSourceName string DataSourceName string
dialect dialect dialect core.Dialect
Tables map[reflect.Type]*Table Tables map[reflect.Type]*core.Table
mutex *sync.Mutex
ShowSQL bool mutex *sync.Mutex
ShowErr bool ShowSQL bool
ShowDebug bool ShowErr bool
ShowWarn bool ShowDebug bool
Pool IConnectPool ShowWarn bool
Filters []Filter Pool IConnectPool
Logger io.Writer Filters []core.Filter
Cacher Cacher Logger io.Writer
UseCache bool Cacher Cacher
tableCachers map[reflect.Type]Cacher
} }
func (engine *Engine) SetMapper(mapper IMapper) { func (engine *Engine) SetMapper(mapper IMapper) {
@ -102,8 +73,8 @@ func (engine *Engine) Quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
} }
// A simple wrapper to dialect's SqlType method // A simple wrapper to dialect's core.SqlType method
func (engine *Engine) SqlType(c *Column) string { func (engine *Engine) SqlType(c *core.Column) string {
return engine.dialect.SqlType(c) return engine.dialect.SqlType(c)
} }
@ -130,12 +101,7 @@ func (engine *Engine) SetMaxIdleConns(conns int) {
// SetDefaltCacher set the default cacher. Xorm's default not enable cacher. // SetDefaltCacher set the default cacher. Xorm's default not enable cacher.
func (engine *Engine) SetDefaultCacher(cacher Cacher) { func (engine *Engine) SetDefaultCacher(cacher Cacher) {
if cacher == nil { engine.Cacher = cacher
engine.UseCache = false
} else {
engine.UseCache = true
engine.Cacher = cacher
}
} }
// If you has set default cacher, and you want temporilly stop use cache, // If you has set default cacher, and you want temporilly stop use cache,
@ -156,7 +122,7 @@ func (engine *Engine) NoCascade() *Session {
func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) {
t := rType(bean) t := rType(bean)
engine.autoMapType(t) engine.autoMapType(t)
engine.Tables[t].Cacher = cacher engine.tableCachers[t] = cacher
} }
// OpenDB provides a interface to operate database directly. // OpenDB provides a interface to operate database directly.
@ -235,7 +201,7 @@ func (engine *Engine) NoAutoTime() *Session {
} }
// Retrieve all tables, columns, indexes' informations from database. // Retrieve all tables, columns, indexes' informations from database.
func (engine *Engine) DBMetas() ([]*Table, error) { func (engine *Engine) DBMetas() ([]*core.Table, error) {
tables, err := engine.dialect.GetTables() tables, err := engine.dialect.GetTables()
if err != nil { if err != nil {
return nil, err return nil, err
@ -246,8 +212,11 @@ func (engine *Engine) DBMetas() ([]*Table, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
table.Columns = cols for _, name := range colSeq {
table.ColumnsSeq = colSeq table.AddColumn(cols[name])
}
//table.Columns = cols
//table.ColumnsSeq = colSeq
indexes, err := engine.dialect.GetIndexes(table.Name) indexes, err := engine.dialect.GetIndexes(table.Name)
if err != nil { if err != nil {
@ -257,10 +226,10 @@ func (engine *Engine) DBMetas() ([]*Table, error) {
for _, index := range indexes { for _, index := range indexes {
for _, name := range index.Cols { for _, name := range index.Cols {
if col, ok := table.Columns[name]; ok { if col := table.GetColumn(name); col != nil {
col.Indexes[index.Name] = true col.Indexes[index.Name] = true
} else { } else {
return nil, fmt.Errorf("Unknown col "+name+" in indexes %v", table.Columns) return nil, fmt.Errorf("Unknown col "+name+" in indexes %v", index)
} }
} }
} }
@ -420,7 +389,7 @@ func (engine *Engine) Having(conditions string) *Session {
return session.Having(conditions) return session.Having(conditions)
} }
func (engine *Engine) autoMapType(t reflect.Type) *Table { func (engine *Engine) autoMapType(t reflect.Type) *core.Table {
engine.mutex.Lock() engine.mutex.Lock()
defer engine.mutex.Unlock() defer engine.mutex.Unlock()
table, ok := engine.Tables[t] table, ok := engine.Tables[t]
@ -431,37 +400,31 @@ func (engine *Engine) autoMapType(t reflect.Type) *Table {
return table return table
} }
func (engine *Engine) autoMap(bean interface{}) *Table { func (engine *Engine) autoMap(bean interface{}) *core.Table {
t := rType(bean) t := rType(bean)
return engine.autoMapType(t) return engine.autoMapType(t)
} }
func (engine *Engine) newTable() *Table { func (engine *Engine) mapType(t reflect.Type) *core.Table {
table := &Table{} return mappingTable(t, engine.tableMapper, engine.columnMapper, engine.dialect, engine.TagIdentifier)
table.Indexes = make(map[string]*Index)
table.Columns = make(map[string]*Column)
table.ColumnsSeq = make([]string, 0)
table.Created = make(map[string]bool)
table.Cacher = engine.Cacher
return table
} }
func (engine *Engine) mapType(t reflect.Type) *Table { func mappingTable(t reflect.Type, tableMapper IMapper, colMapper IMapper, dialect core.Dialect, tagId string) *core.Table {
table := engine.newTable() table := core.NewEmptyTable()
table.Name = engine.tableMapper.Obj2Table(t.Name()) table.Name = tableMapper.Obj2Table(t.Name())
table.Type = t table.Type = t
var idFieldColName string var idFieldColName string
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
tag := t.Field(i).Tag tag := t.Field(i).Tag
ormTagStr := tag.Get(engine.TagIdentifier) ormTagStr := tag.Get(tagId)
var col *Column var col *core.Column
fieldType := t.Field(i).Type fieldType := t.Field(i).Type
if ormTagStr != "" { if ormTagStr != "" {
col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false,
IsAutoIncrement: false, MapType: TWOSIDES, Indexes: make(map[string]bool)} IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)}
tags := strings.Split(ormTagStr, " ") tags := strings.Split(ormTagStr, " ")
if len(tags) > 0 { if len(tags) > 0 {
@ -470,25 +433,24 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
} }
if (strings.ToUpper(tags[0]) == "EXTENDS") && if (strings.ToUpper(tags[0]) == "EXTENDS") &&
(fieldType.Kind() == reflect.Struct) { (fieldType.Kind() == reflect.Struct) {
parentTable := engine.mapType(fieldType) parentTable := mappingTable(fieldType, tableMapper, colMapper, dialect, tagId)
for name, col := range parentTable.Columns { for _, col := range parentTable.Columns() {
col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName)
table.Columns[strings.ToLower(name)] = col table.AddColumn(col)
table.ColumnsSeq = append(table.ColumnsSeq, name)
} }
table.PrimaryKeys = parentTable.PrimaryKeys
continue continue
} }
var indexType int var indexType int
var indexName string var indexName string
for j, key := range tags { for j, key := range tags {
k := strings.ToUpper(key) k := strings.ToUpper(key)
switch { switch {
case k == "<-": case k == "<-":
col.MapType = ONLYFROMDB col.MapType = core.ONLYFROMDB
case k == "->": case k == "->":
col.MapType = ONLYTODB col.MapType = core.ONLYTODB
case k == "PK": case k == "PK":
col.IsPrimaryKey = true col.IsPrimaryKey = true
col.Nullable = false col.Nullable = false
@ -506,15 +468,15 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
case k == "UPDATED": case k == "UPDATED":
col.IsUpdated = true col.IsUpdated = true
case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"): case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
indexType = IndexType indexType = core.IndexType
indexName = k[len("INDEX")+1 : len(k)-1] indexName = k[len("INDEX")+1 : len(k)-1]
case k == "INDEX": case k == "INDEX":
indexType = IndexType indexType = core.IndexType
case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"): case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
indexName = k[len("UNIQUE")+1 : len(k)-1] indexName = k[len("UNIQUE")+1 : len(k)-1]
indexType = UniqueType indexType = core.UniqueType
case k == "UNIQUE": case k == "UNIQUE":
indexType = UniqueType indexType = core.UniqueType
case k == "NOTNULL": case k == "NOTNULL":
col.Nullable = false col.Nullable = false
case k == "NOT": case k == "NOT":
@ -525,10 +487,10 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
} }
} else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") { } else if strings.Contains(k, "(") && strings.HasSuffix(k, ")") {
fs := strings.Split(k, "(") fs := strings.Split(k, "(")
if _, ok := sqlTypes[fs[0]]; !ok { if _, ok := core.SqlTypes[fs[0]]; !ok {
continue continue
} }
col.SQLType = SQLType{fs[0], 0, 0} col.SQLType = core.SQLType{fs[0], 0, 0}
fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",")
if len(fs2) == 2 { if len(fs2) == 2 {
col.Length, _ = strconv.Atoi(fs2[0]) col.Length, _ = strconv.Atoi(fs2[0])
@ -537,17 +499,17 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
col.Length, _ = strconv.Atoi(fs2[0]) col.Length, _ = strconv.Atoi(fs2[0])
} }
} else { } else {
if _, ok := sqlTypes[k]; ok { if _, ok := core.SqlTypes[k]; ok {
col.SQLType = SQLType{k, 0, 0} col.SQLType = core.SQLType{k, 0, 0}
} else if key != col.Default { } else if key != col.Default {
col.Name = key col.Name = key
} }
} }
engine.SqlType(col) dialect.SqlType(col)
} }
} }
if col.SQLType.Name == "" { if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType) col.SQLType = core.Type2SQLType(fieldType)
} }
if col.Length == 0 { if col.Length == 0 {
col.Length = col.SQLType.DefaultLength col.Length = col.SQLType.DefaultLength
@ -556,9 +518,9 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
col.Length2 = col.SQLType.DefaultLength2 col.Length2 = col.SQLType.DefaultLength2
} }
if col.Name == "" { if col.Name == "" {
col.Name = engine.columnMapper.Obj2Table(t.Field(i).Name) col.Name = colMapper.Obj2Table(t.Field(i).Name)
} }
if indexType == IndexType { if indexType == core.IndexType {
if indexName == "" { if indexName == "" {
indexName = col.Name indexName = col.Name
} }
@ -566,12 +528,12 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
index.AddColumn(col.Name) index.AddColumn(col.Name)
col.Indexes[index.Name] = true col.Indexes[index.Name] = true
} else { } else {
index := NewIndex(indexName, IndexType) index := core.NewIndex(indexName, core.IndexType)
index.AddColumn(col.Name) index.AddColumn(col.Name)
table.AddIndex(index) table.AddIndex(index)
col.Indexes[index.Name] = true col.Indexes[index.Name] = true
} }
} else if indexType == UniqueType { } else if indexType == core.UniqueType {
if indexName == "" { if indexName == "" {
indexName = col.Name indexName = col.Name
} }
@ -579,7 +541,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
index.AddColumn(col.Name) index.AddColumn(col.Name)
col.Indexes[index.Name] = true col.Indexes[index.Name] = true
} else { } else {
index := NewIndex(indexName, UniqueType) index := core.NewIndex(indexName, core.UniqueType)
index.AddColumn(col.Name) index.AddColumn(col.Name)
table.AddIndex(index) table.AddIndex(index)
col.Indexes[index.Name] = true col.Indexes[index.Name] = true
@ -587,10 +549,9 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
} }
} }
} else { } else {
sqlType := Type2SQLType(fieldType) sqlType := core.Type2SQLType(fieldType)
col = &Column{engine.columnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, col = core.NewColumn(colMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false, sqlType.DefaultLength, sqlType.DefaultLength2, true)
TWOSIDES, false, false, false, false}
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
col.Nullable = false col.Nullable = false
@ -604,7 +565,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table {
} }
if idFieldColName != "" && len(table.PrimaryKeys) == 0 { if idFieldColName != "" && len(table.PrimaryKeys) == 0 {
col := table.Columns[strings.ToLower(idFieldColName)] col := table.GetColumn(idFieldColName)
col.IsPrimaryKey = true col.IsPrimaryKey = true
col.IsAutoIncrement = true col.IsAutoIncrement = true
col.Nullable = false col.Nullable = false
@ -666,6 +627,13 @@ func (engine *Engine) CreateUniques(bean interface{}) error {
return session.CreateUniques(bean) return session.CreateUniques(bean)
} }
func (engine *Engine) getCacher(t reflect.Type) Cacher {
if cacher, ok := engine.tableCachers[t]; ok {
return cacher
}
return engine.Cacher
}
// If enabled cache, clear the cache bean // If enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error {
t := rType(bean) t := rType(bean)
@ -673,9 +641,10 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error {
return errors.New("error params") return errors.New("error params")
} }
table := engine.autoMap(bean) table := engine.autoMap(bean)
if table.Cacher != nil { cacher := engine.getCacher(t)
table.Cacher.ClearIds(table.Name) if cacher != nil {
table.Cacher.DelBean(table.Name, id) cacher.ClearIds(table.Name)
cacher.DelBean(table.Name, id)
} }
return nil return nil
} }
@ -688,9 +657,10 @@ func (engine *Engine) ClearCache(beans ...interface{}) error {
return errors.New("error params") return errors.New("error params")
} }
table := engine.autoMap(bean) table := engine.autoMap(bean)
if table.Cacher != nil { cacher := engine.getCacher(t)
table.Cacher.ClearIds(table.Name) if cacher != nil {
table.Cacher.ClearBeans(table.Name) cacher.ClearIds(table.Name)
cacher.ClearBeans(table.Name)
} }
} }
return nil return nil
@ -730,7 +700,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
} else { } else {
for _, col := range table.Columns { for _, col := range table.Columns() {
session := engine.NewSession() session := engine.NewSession()
session.Statement.RefTable = table session.Statement.RefTable = table
defer session.Close() defer session.Close()
@ -753,7 +723,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession() session := engine.NewSession()
session.Statement.RefTable = table session.Statement.RefTable = table
defer session.Close() defer session.Close()
if index.Type == UniqueType { if index.Type == core.UniqueType {
//isExist, err := session.isIndexExist(table.Name, name, true) //isExist, err := session.isIndexExist(table.Name, name, true)
isExist, err := session.isIndexExist2(table.Name, index.Cols, true) isExist, err := session.isIndexExist2(table.Name, index.Cols, true)
if err != nil { if err != nil {
@ -768,7 +738,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err return err
} }
} }
} else if index.Type == IndexType { } else if index.Type == core.IndexType {
isExist, err := session.isIndexExist2(table.Name, index.Cols, false) isExist, err := session.isIndexExist2(table.Name, index.Cols, false)
if err != nil { if err != nil {
return err return err

View File

@ -1,109 +1,109 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func main() { func main() {
f := "cache.db" f := "cache.db"
os.Remove(f) os.Remove(f)
Orm, err := xorm.NewEngine("sqlite3", f) Orm, err := xorm.NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)
Orm.SetDefaultCacher(cacher) Orm.SetDefaultCacher(cacher)
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{Name: "xlw"}) _, err = Orm.Insert(&User{Name: "xlw"})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
users := make([]User, 0) users := make([]User, 0)
err = Orm.Find(&users) err = Orm.Find(&users)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("users:", users) fmt.Println("users:", users)
users2 := make([]User, 0) users2 := make([]User, 0)
err = Orm.Find(&users2) err = Orm.Find(&users2)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("users2:", users2) fmt.Println("users2:", users2)
users3 := make([]User, 0) users3 := make([]User, 0)
err = Orm.Find(&users3) err = Orm.Find(&users3)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("users3:", users3) fmt.Println("users3:", users3)
user4 := new(User) user4 := new(User)
has, err := Orm.Id(1).Get(user4) has, err := Orm.Id(1).Get(user4)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user4:", has, user4) fmt.Println("user4:", has, user4)
user4.Name = "xiaolunwen" user4.Name = "xiaolunwen"
_, err = Orm.Id(1).Update(user4) _, err = Orm.Id(1).Update(user4)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user4:", user4) fmt.Println("user4:", user4)
user5 := new(User) user5 := new(User)
has, err = Orm.Id(1).Get(user5) has, err = Orm.Id(1).Get(user5)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user5:", has, user5) fmt.Println("user5:", has, user5)
_, err = Orm.Id(1).Delete(new(User)) _, err = Orm.Id(1).Delete(new(User))
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
for { for {
user6 := new(User) user6 := new(User)
has, err = Orm.Id(1).Get(user6) has, err = Orm.Id(1).Get(user6)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println("user6:", has, user6) fmt.Println("user6:", has, user6)
} }
} }

View File

@ -1,105 +1,105 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db") return xorm.NewEngine("sqlite3", "./goroutine.db")
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}
func test(engine *xorm.Engine) { func test(engine *xorm.Engine) {
err := engine.CreateTables(u) err := engine.CreateTables(u)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
size := 500 size := 500
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(x int) { go func(x int) {
//x := i //x := i
err := engine.Ping() err := engine.Ping()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
if x+j < 2 { if x+j < 2 {
_, err = engine.Get(u) _, err = engine.Get(u)
} else if x+j < 4 { } else if x+j < 4 {
users := make([]User, 0) users := make([]User, 0)
err = engine.Find(&users) err = engine.Find(&users)
} else if x+j < 8 { } else if x+j < 8 {
_, err = engine.Count(u) _, err = engine.Count(u)
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
//_, err = engine.Id(1).Delete(u) //_, err = engine.Id(1).Delete(u)
_, err = engine.Delete(u) _, err = engine.Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
queue <- x queue <- x
return return
} }
} }
fmt.Printf("%v success!\n", x) fmt.Printf("%v success!\n", x)
} }
queue <- x queue <- x
}(i) }(i)
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
<-queue <-queue
} }
//conns := atomic.LoadInt32(&xorm.ConnectionNum) //conns := atomic.LoadInt32(&xorm.ConnectionNum)
//fmt.Println("connection number:", conns) //fmt.Println("connection number:", conns)
fmt.Println("end") fmt.Println("end")
} }
func main() { func main() {
fmt.Println("-----start sqlite go routines-----") fmt.Println("-----start sqlite go routines-----")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) cacher := xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)
engine.SetDefaultCacher(cacher) engine.SetDefaultCacher(cacher)
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("test end") fmt.Println("test end")
engine.Close() engine.Close()
fmt.Println("-----start mysql go routines-----") fmt.Println("-----start mysql go routines-----")
engine, err = mysqlEngine() engine, err = mysqlEngine()
engine.ShowSQL = true engine.ShowSQL = true
cacher = xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000) cacher = xorm.NewLRUCacher(xorm.NewMemoryStore(), 1000)
engine.SetDefaultCacher(cacher) engine.SetDefaultCacher(cacher)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer engine.Close() defer engine.Close()
test(engine) test(engine)
} }

View File

@ -1,76 +1,76 @@
package main package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
) )
type Status struct { type Status struct {
Name string Name string
Color string Color string
} }
var ( var (
Registed Status = Status{"Registed", "white"} Registed Status = Status{"Registed", "white"}
Approved Status = Status{"Approved", "green"} Approved Status = Status{"Approved", "green"}
Removed Status = Status{"Removed", "red"} Removed Status = Status{"Removed", "red"}
Statuses map[string]Status = map[string]Status{ Statuses map[string]Status = map[string]Status{
Registed.Name: Registed, Registed.Name: Registed,
Approved.Name: Approved, Approved.Name: Approved,
Removed.Name: Removed, Removed.Name: Removed,
} }
) )
func (s *Status) FromDB(bytes []byte) error { func (s *Status) FromDB(bytes []byte) error {
if r, ok := Statuses[string(bytes)]; ok { if r, ok := Statuses[string(bytes)]; ok {
*s = r *s = r
return nil return nil
} else { } else {
return errors.New("no this data") return errors.New("no this data")
} }
} }
func (s *Status) ToDB() ([]byte, error) { func (s *Status) ToDB() ([]byte, error) {
return []byte(s.Name), nil return []byte(s.Name), nil
} }
type User struct { type User struct {
Id int64 Id int64
Name string Name string
Status Status `xorm:"varchar(40)"` Status Status `xorm:"varchar(40)"`
} }
func main() { func main() {
f := "conversion.db" f := "conversion.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{1, "xlw", Registed}) _, err = Orm.Insert(&User{1, "xlw", Registed})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
users := make([]User, 0) users := make([]User, 0)
err = Orm.Find(&users) err = Orm.Find(&users)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(users) fmt.Println(users)
} }

View File

@ -1,66 +1,66 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
type LoginInfo struct { type LoginInfo struct {
Id int64 Id int64
IP string IP string
UserId int64 UserId int64
} }
type LoginInfo1 struct { type LoginInfo1 struct {
LoginInfo `xorm:"extends"` LoginInfo `xorm:"extends"`
UserName string UserName string
} }
func main() { func main() {
f := "derive.db" f := "derive.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer Orm.Close() defer Orm.Close()
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}, &LoginInfo{}) err = Orm.CreateTables(&User{}, &LoginInfo{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1}) _, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
info := LoginInfo{} info := LoginInfo{}
_, err = Orm.Id(1).Get(&info) _, err = Orm.Id(1).Get(&info)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(info) fmt.Println(info)
infos := make([]LoginInfo1, 0) infos := make([]LoginInfo1, 0)
err = Orm.Sql(`select *, (select name from user where id = login_info.user_id) as user_name from err = Orm.Sql(`select *, (select name from user where id = login_info.user_id) as user_name from
login_info limit 10`).Find(&infos) login_info limit 10`).Find(&infos)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(infos) fmt.Println(infos)
} }

View File

@ -1,106 +1,106 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
"runtime" "runtime"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db") return xorm.NewEngine("sqlite3", "./goroutine.db")
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}
func test(engine *xorm.Engine) { func test(engine *xorm.Engine) {
err := engine.CreateTables(u) err := engine.CreateTables(u)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
size := 500 size := 500
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(x int) { go func(x int) {
//x := i //x := i
err := engine.Test() err := engine.Test()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
err = engine.Map(u) err = engine.Map(u)
if err != nil { if err != nil {
fmt.Println("Map user failed") fmt.Println("Map user failed")
} else { } else {
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
if x+j < 2 { if x+j < 2 {
_, err = engine.Get(u) _, err = engine.Get(u)
} else if x+j < 4 { } else if x+j < 4 {
users := make([]User, 0) users := make([]User, 0)
err = engine.Find(&users) err = engine.Find(&users)
} else if x+j < 8 { } else if x+j < 8 {
_, err = engine.Count(u) _, err = engine.Count(u)
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
_, err = engine.Id(1).Delete(u) _, err = engine.Id(1).Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
queue <- x queue <- x
return return
} }
} }
fmt.Printf("%v success!\n", x) fmt.Printf("%v success!\n", x)
} }
} }
queue <- x queue <- x
}(i) }(i)
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
<-queue <-queue
} }
//conns := atomic.LoadInt32(&xorm.ConnectionNum) //conns := atomic.LoadInt32(&xorm.ConnectionNum)
//fmt.Println("connection number:", conns) //fmt.Println("connection number:", conns)
fmt.Println("end") fmt.Println("end")
} }
func main() { func main() {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
fmt.Println("-----start sqlite go routines-----") fmt.Println("-----start sqlite go routines-----")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("test end") fmt.Println("test end")
engine.Close() engine.Close()
fmt.Println("-----start mysql go routines-----") fmt.Println("-----start mysql go routines-----")
engine, err = mysqlEngine() engine, err = mysqlEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer engine.Close() defer engine.Close()
test(engine) test(engine)
} }

View File

@ -1,106 +1,106 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
xorm "github.com/lunny/xorm" xorm "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
"runtime" "runtime"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
os.Remove("./test.db") os.Remove("./test.db")
return xorm.NewEngine("sqlite3", "./goroutine.db") return xorm.NewEngine("sqlite3", "./goroutine.db")
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
var u *User = &User{} var u *User = &User{}
func test(engine *xorm.Engine) { func test(engine *xorm.Engine) {
err := engine.CreateTables(u) err := engine.CreateTables(u)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
engine.Pool.SetMaxConns(5) engine.Pool.SetMaxConns(5)
size := 1000 size := 1000
queue := make(chan int, size) queue := make(chan int, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(x int) { go func(x int) {
//x := i //x := i
err := engine.Test() err := engine.Test()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
err = engine.Map(u) err = engine.Map(u)
if err != nil { if err != nil {
fmt.Println("Map user failed") fmt.Println("Map user failed")
} else { } else {
for j := 0; j < 10; j++ { for j := 0; j < 10; j++ {
if x+j < 2 { if x+j < 2 {
_, err = engine.Get(u) _, err = engine.Get(u)
} else if x+j < 4 { } else if x+j < 4 {
users := make([]User, 0) users := make([]User, 0)
err = engine.Find(&users) err = engine.Find(&users)
} else if x+j < 8 { } else if x+j < 8 {
_, err = engine.Count(u) _, err = engine.Count(u)
} else if x+j < 16 { } else if x+j < 16 {
_, err = engine.Insert(&User{Name: "xlw"}) _, err = engine.Insert(&User{Name: "xlw"})
} else if x+j < 32 { } else if x+j < 32 {
_, err = engine.Id(1).Delete(u) _, err = engine.Id(1).Delete(u)
} }
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
queue <- x queue <- x
return return
} }
} }
fmt.Printf("%v success!\n", x) fmt.Printf("%v success!\n", x)
} }
} }
queue <- x queue <- x
}(i) }(i)
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
<-queue <-queue
} }
fmt.Println("end") fmt.Println("end")
} }
func main() { func main() {
runtime.GOMAXPROCS(2) runtime.GOMAXPROCS(2)
fmt.Println("create engine") fmt.Println("create engine")
engine, err := sqliteEngine() engine, err := sqliteEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
engine.ShowSQL = true engine.ShowSQL = true
fmt.Println(engine) fmt.Println(engine)
test(engine) test(engine)
fmt.Println("------------------------") fmt.Println("------------------------")
engine.Close() engine.Close()
engine, err = mysqlEngine() engine, err = mysqlEngine()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
defer engine.Close() defer engine.Close()
test(engine) test(engine)
} }

View File

@ -1,45 +1,45 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
func main() { func main() {
f := "pool.db" f := "pool.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
err = Orm.SetPool(NewSimpleConnectPool()) err = Orm.SetPool(NewSimpleConnectPool())
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}) err = Orm.CreateTables(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
_, err = Orm.Get(&User{}) _, err = Orm.Get(&User{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
break break
} }
} }
} }

View File

@ -1,54 +1,54 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"os" "os"
) )
type User struct { type User struct {
Id int64 Id int64
Name string Name string
} }
type LoginInfo struct { type LoginInfo struct {
Id int64 Id int64
IP string IP string
UserId int64 UserId int64
// timestamp should be updated by database, so only allow get from db // timestamp should be updated by database, so only allow get from db
TimeStamp string `xorm:"<-"` TimeStamp string `xorm:"<-"`
// assume // assume
Nonuse int `xorm:"->"` Nonuse int `xorm:"->"`
} }
func main() { func main() {
f := "singleMapping.db" f := "singleMapping.db"
os.Remove(f) os.Remove(f)
Orm, err := NewEngine("sqlite3", f) Orm, err := NewEngine("sqlite3", f)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = Orm.CreateTables(&User{}, &LoginInfo{}) err = Orm.CreateTables(&User{}, &LoginInfo{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
_, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1, "", 23}) _, err = Orm.Insert(&User{1, "xlw"}, &LoginInfo{1, "127.0.0.1", 1, "", 23})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
info := LoginInfo{} info := LoginInfo{}
_, err = Orm.Id(1).Get(&info) _, err = Orm.Id(1).Get(&info)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
fmt.Println(info) fmt.Println(info)
} }

View File

@ -1,92 +1,92 @@
package main package main
import ( import (
"fmt" "fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
type SyncUser2 struct { type SyncUser2 struct {
Id int64 Id int64
Name string `xorm:"unique"` Name string `xorm:"unique"`
Age int `xorm:"index"` Age int `xorm:"index"`
Title string Title string
Address string Address string
Genre string Genre string
Area string Area string
Date int Date int
} }
type SyncLoginInfo2 struct { type SyncLoginInfo2 struct {
Id int64 Id int64
IP string `xorm:"index"` IP string `xorm:"index"`
UserId int64 UserId int64
AddedCol int AddedCol int
// timestamp should be updated by database, so only allow get from db // timestamp should be updated by database, so only allow get from db
TimeStamp string TimeStamp string
// assume // assume
Nonuse int `xorm:"unique"` Nonuse int `xorm:"unique"`
Newa string `xorm:"index"` Newa string `xorm:"index"`
} }
func sync(engine *xorm.Engine) error { func sync(engine *xorm.Engine) error {
return engine.Sync(&SyncLoginInfo2{}, &SyncUser2{}) return engine.Sync(&SyncLoginInfo2{}, &SyncUser2{})
} }
func sqliteEngine() (*xorm.Engine, error) { func sqliteEngine() (*xorm.Engine, error) {
f := "sync.db" f := "sync.db"
//os.Remove(f) //os.Remove(f)
return xorm.NewEngine("sqlite3", f) return xorm.NewEngine("sqlite3", f)
} }
func mysqlEngine() (*xorm.Engine, error) { func mysqlEngine() (*xorm.Engine, error) {
return xorm.NewEngine("mysql", "root:@/test?charset=utf8") return xorm.NewEngine("mysql", "root:@/test?charset=utf8")
} }
func postgresEngine() (*xorm.Engine, error) { func postgresEngine() (*xorm.Engine, error) {
return xorm.NewEngine("postgres", "dbname=xorm_test sslmode=disable") return xorm.NewEngine("postgres", "dbname=xorm_test sslmode=disable")
} }
type engineFunc func() (*xorm.Engine, error) type engineFunc func() (*xorm.Engine, error)
func main() { func main() {
//engines := []engineFunc{sqliteEngine, mysqlEngine, postgresEngine} //engines := []engineFunc{sqliteEngine, mysqlEngine, postgresEngine}
//engines := []engineFunc{sqliteEngine} //engines := []engineFunc{sqliteEngine}
//engines := []engineFunc{mysqlEngine} //engines := []engineFunc{mysqlEngine}
engines := []engineFunc{postgresEngine} engines := []engineFunc{postgresEngine}
for _, enginefunc := range engines { for _, enginefunc := range engines {
Orm, err := enginefunc() Orm, err := enginefunc()
fmt.Println("--------", Orm.DriverName, "----------") fmt.Println("--------", Orm.DriverName, "----------")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
Orm.ShowSQL = true Orm.ShowSQL = true
err = sync(Orm) err = sync(Orm)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
_, err = Orm.Where("id > 0").Delete(&SyncUser2{}) _, err = Orm.Where("id > 0").Delete(&SyncUser2{})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
user := &SyncUser2{ user := &SyncUser2{
Name: "testsdf", Name: "testsdf",
Age: 15, Age: 15,
Title: "newsfds", Title: "newsfds",
Address: "fasfdsafdsaf", Address: "fasfdsafdsaf",
Genre: "fsafd", Genre: "fsafd",
Area: "fafdsafd", Area: "fafdsafd",
Date: 1000, Date: 1000,
} }
_, err = Orm.Insert(user) _, err = Orm.Insert(user)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
} }
} }

View File

@ -1,49 +0,0 @@
package xorm
import (
"fmt"
"strings"
)
// Filter is an interface to filter SQL
type Filter interface {
Do(sql string, session *Session) string
}
// PgSeqFilter filter SQL replace ?, ? ... to $1, $2 ...
type PgSeqFilter struct {
}
func (s *PgSeqFilter) Do(sql string, session *Session) string {
segs := strings.Split(sql, "?")
size := len(segs)
res := ""
for i, c := range segs {
if i < size-1 {
res += c + fmt.Sprintf("$%v", i+1)
}
}
res += segs[size-1]
return res
}
// QuoteFilter filter SQL replace ` to database's own quote character
type QuoteFilter struct {
}
func (s *QuoteFilter) Do(sql string, session *Session) string {
return strings.Replace(sql, "`", session.Engine.QuoteStr(), -1)
}
// IdFilter filter SQL replace (id) to primary key column name
type IdFilter struct {
}
func (i *IdFilter) Do(sql string, session *Session) string {
if session.Statement.RefTable != nil && len(session.Statement.RefTable.PrimaryKeys) == 1 {
sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1)
sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1)
return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1)
}
return sql
}

344
mysql.go
View File

@ -1,344 +0,0 @@
package xorm
import (
"crypto/tls"
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
type uri struct {
dbType string
proto string
host string
port string
dbName string
user string
passwd string
charset string
laddr string
raddr string
timeout time.Duration
}
type parser interface {
parse(driverName, dataSourceName string) (*uri, error)
}
type mysqlParser struct {
}
func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &uri{dbType: MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.dbName = match
case "params":
if len(match) > 0 {
kvs := strings.Split(match, "&")
for _, kv := range kvs {
splits := strings.Split(kv, "=")
if len(splits) == 2 {
switch splits[0] {
case "charset":
uri.charset = splits[1]
}
}
}
}
}
}
return uri, nil
}
type base struct {
parser parser
driverName string
dataSourceName string
*uri
}
func (b *base) init(parser parser, drivername, dataSourceName string) (err error) {
b.parser = parser
b.driverName, b.dataSourceName = drivername, dataSourceName
b.uri, err = b.parser.parse(b.driverName, b.dataSourceName)
return
}
func (b *base) URI() *uri {
return b.uri
}
func (b *base) DBType() string {
return b.uri.dbType
}
type mysql struct {
base
net string
addr string
params map[string]string
loc *time.Location
timeout time.Duration
tls *tls.Config
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
func (db *mysql) Init(drivername, uri string) error {
return db.base.init(&mysqlParser{}, drivername, uri)
}
func (db *mysql) SqlType(c *Column) string {
var res string
switch t := c.SQLType.Name; t {
case Bool:
res = TinyInt
case Serial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = Int
case BigSerial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = BigInt
case Bytea:
res = Blob
case TimeStampz:
res = Char
c.Length = 64
default:
res = t
}
var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0)
if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
}
return res
}
func (db *mysql) SupportInsertMany() bool {
return true
}
func (db *mysql) QuoteStr() string {
return "`"
}
func (db *mysql) SupportEngine() bool {
return true
}
func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT"
}
func (db *mysql) SupportCharset() bool {
return true
}
func (db *mysql) IndexOnTable() bool {
return true
}
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.dbName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args
}
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.dbName, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args
}
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.dbName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args
}
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.dbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil {
return nil, nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, nil, err
}
cols := make(map[string]*Column)
colSeq := make([]string, 0)
for _, record := range res {
col := new(Column)
col.Indexes = make(map[string]bool)
for name, content := range record {
switch name {
case "COLUMN_NAME":
col.Name = strings.Trim(string(content), "` ")
case "IS_NULLABLE":
if "YES" == string(content) {
col.Nullable = true
}
case "COLUMN_DEFAULT":
// add ''
col.Default = string(content)
case "COLUMN_TYPE":
cts := strings.Split(string(content), "(")
var len1, len2 int
if len(cts) == 2 {
idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil {
return nil, nil, err
}
if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1])
if err != nil {
return nil, nil, err
}
}
}
colName := cts[0]
colType := strings.ToUpper(colName)
col.Length = len1
col.Length2 = len2
if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2}
} else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
}
case "COLUMN_KEY":
key := string(content)
if key == "PRI" {
col.IsPrimaryKey = true
}
if key == "UNI" {
//col.is
}
case "EXTRA":
extra := string(content)
if extra == "auto_increment" {
col.IsAutoIncrement = true
}
}
}
if col.SQLType.IsText() {
if col.Default != "" {
col.Default = "'" + col.Default + "'"
}
}
cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
return colSeq, cols, nil
}
func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.dbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
tables := make([]*Table, 0)
for _, record := range res {
table := new(Table)
for name, content := range record {
switch name {
case "TABLE_NAME":
table.Name = strings.Trim(string(content), "` ")
case "ENGINE":
}
}
tables = append(tables, table)
}
return tables, nil
}
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.dbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil {
return nil, err
}
defer cnn.Close()
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
indexes := make(map[string]*Index, 0)
for _, record := range res {
var indexType int
var indexName, colName string
for name, content := range record {
switch name {
case "NON_UNIQUE":
if "YES" == string(content) || string(content) == "1" {
indexType = IndexType
} else {
indexType = UniqueType
}
case "INDEX_NAME":
indexName = string(content)
case "COLUMN_NAME":
colName = strings.Trim(string(content), "` ")
}
}
if indexName == "PRIMARY" {
continue
}
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)]
}
var index *Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(Index)
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
}
index.AddColumn(colName)
}
return indexes, nil
}

View File

@ -35,6 +35,30 @@ func TestMysql(t *testing.T) {
testAll3(engine, t) testAll3(engine, t)
} }
func TestMysqlSameMapper(t *testing.T) {
err := mysqlDdlImport()
if err != nil {
t.Error(err)
return
}
engine, err := NewEngine("mysql", "root:@/xorm_test3?charset=utf8")
defer engine.Close()
if err != nil {
t.Error(err)
return
}
engine.ShowSQL = showTestSql
engine.ShowErr = showTestSql
engine.ShowWarn = showTestSql
engine.ShowDebug = showTestSql
engine.SetMapper(SameMapper{})
testAll(engine, t)
testAll2(engine, t)
testAll3(engine, t)
}
func TestMysqlWithCache(t *testing.T) { func TestMysqlWithCache(t *testing.T) {
err := mysqlDdlImport() err := mysqlDdlImport()
if err != nil { if err != nil {

View File

@ -41,7 +41,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) {
} }
for _, filter := range rows.session.Engine.Filters { for _, filter := range rows.session.Engine.Filters {
sql = filter.Do(sql, session) sql = filter.Do(sql, session.Engine.dialect, rows.session.Statement.RefTable)
} }
rows.session.Engine.LogSQL(sql) rows.session.Engine.LogSQL(sql)

View File

@ -9,6 +9,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/lunny/xorm/core"
) )
// Struct Session keep a pointer to sql.DB and provides all execution of all // Struct Session keep a pointer to sql.DB and provides all execution of all
@ -108,7 +110,7 @@ func (session *Session) After(closures func(interface{})) *Session {
return session return session
} }
// Method Table can input a string or pointer to struct for special a table to operate. // Method core.Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session { func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableNameOrBean) session.Statement.Table(tableNameOrBean)
return session return session
@ -354,13 +356,10 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
} }
table := session.Engine.autoMapType(rType(obj)) table := session.Engine.autoMapType(rType(obj))
var col *core.Column
for key, data := range objMap { for key, data := range objMap {
key = strings.ToLower(key) if col = table.GetColumn(key); col == nil {
var col *Column session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.Columns()))
var ok bool
if col, ok = table.Columns[key]; !ok {
session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq))
continue continue
} }
@ -410,7 +409,7 @@ func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Resul
func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) {
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sqlStr = filter.Do(sqlStr, session) sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
session.Engine.LogSQL(sqlStr) session.Engine.LogSQL(sqlStr)
@ -596,14 +595,14 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf
return false, ErrCacheFailed return false, ErrCacheFailed
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sqlStr = filter.Do(sqlStr, session) sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
newsql := session.Statement.convertIdSql(sqlStr) newsql := session.Statement.convertIdSql(sqlStr)
if newsql == "" { if newsql == "" {
return false, ErrCacheFailed return false, ErrCacheFailed
} }
cacher := session.Statement.RefTable.Cacher cacher := session.Engine.getCacher(session.Statement.RefTable.Type)
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args) session.Engine.LogDebug("[xorm:cacheGet] find sql:", newsql, args)
ids, err := getCacheSql(cacher, tableName, newsql, args) ids, err := getCacheSql(cacher, tableName, newsql, args)
@ -679,7 +678,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sqlStr = filter.Do(sqlStr, session) sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
newsql := session.Statement.convertIdSql(sqlStr) newsql := session.Statement.convertIdSql(sqlStr)
@ -688,7 +687,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in
} }
table := session.Statement.RefTable table := session.Statement.RefTable
cacher := table.Cacher cacher := session.Engine.getCacher(t)
ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args) ids, err := getCacheSql(cacher, session.Statement.TableName(), newsql, args)
if err != nil { if err != nil {
//session.Engine.LogError(err) //session.Engine.LogError(err)
@ -892,7 +891,7 @@ func (session *Session) Get(bean interface{}) (bool, error) {
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if session.Statement.RefTable.Cacher != nil && session.Statement.UseCache { if cacher := session.Engine.getCacher(session.Statement.RefTable.Type); cacher != nil && session.Statement.UseCache {
has, err := session.cacheGet(bean, sqlStr, args...) has, err := session.cacheGet(bean, sqlStr, args...)
if err != ErrCacheFailed { if err != ErrCacheFailed {
return has, err return has, err
@ -1003,7 +1002,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
var table *Table var table *core.Table
if session.Statement.RefTable == nil { if session.Statement.RefTable == nil {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct { if sliceElementType.Elem().Kind() == reflect.Struct {
@ -1045,7 +1044,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
args = session.Statement.RawParams args = session.Statement.RawParams
} }
if table.Cacher != nil && if cacher := session.Engine.getCacher(table.Type); cacher != nil &&
session.Statement.UseCache && session.Statement.UseCache &&
!session.Statement.IsDistinct { !session.Statement.IsDistinct {
err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...) err = session.cacheFind(sliceElementType, sqlStr, rowsSlicePtr, args...)
@ -1092,25 +1091,41 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
fieldsCount := len(fields) fieldsCount := len(fields)
for rawRows.Next() { var newElemFunc func() reflect.Value
var newValue reflect.Value if sliceElementType.Kind() == reflect.Ptr {
newElemFunc = func() reflect.Value {
return reflect.New(sliceElementType.Elem())
}
} else {
newElemFunc = func() reflect.Value {
return reflect.New(sliceElementType)
}
}
var sliceValueSetFunc func(*reflect.Value)
if sliceValue.Kind() == reflect.Slice {
if sliceElementType.Kind() == reflect.Ptr { if sliceElementType.Kind() == reflect.Ptr {
newValue = reflect.New(sliceElementType.Elem()) sliceValueSetFunc = func(newValue *reflect.Value) {
} else {
newValue = reflect.New(sliceElementType)
}
err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface())
if err != nil {
return err
}
if sliceValue.Kind() == reflect.Slice {
if sliceElementType.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface()))) sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface())))
} else { }
} else {
sliceValueSetFunc = func(newValue *reflect.Value) {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
} }
} }
} }
for rawRows.Next() {
var newValue reflect.Value = newElemFunc()
if sliceValueSetFunc != nil {
err := session.row2Bean(rawRows, fields, fieldsCount, newValue.Interface())
if err != nil {
return err
}
sliceValueSetFunc(&newValue)
}
}
} else { } else {
resultsSlice, err := session.query(sqlStr, args...) resultsSlice, err := session.query(sqlStr, args...)
if err != nil { if err != nil {
@ -1236,9 +1251,9 @@ func (session *Session) isIndexExist2(tableName string, cols []string, unique bo
for _, index := range indexes { for _, index := range indexes {
if sliceEq(index.Cols, cols) { if sliceEq(index.Cols, cols) {
if unique { if unique {
return index.Type == UniqueType, nil return index.Type == core.UniqueType, nil
} else { } else {
return index.Type == IndexType, nil return index.Type == core.IndexType, nil
} }
} }
} }
@ -1256,7 +1271,7 @@ func (session *Session) addColumn(colName string) error {
} }
//fmt.Println(session.Statement.RefTable) //fmt.Println(session.Statement.RefTable)
col := session.Statement.RefTable.Columns[strings.ToLower(colName)] col := session.Statement.RefTable.GetColumn(colName)
sql, args := session.Statement.genAddColumnStr(col) sql, args := session.Statement.genAddColumnStr(col)
_, err = session.exec(sql, args...) _, err = session.exec(sql, args...)
return err return err
@ -1345,34 +1360,25 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err
return result, nil return result, nil
} }
func (session *Session) getField(dataStruct *reflect.Value, key string, table *Table) *reflect.Value { func (session *Session) getField(dataStruct *reflect.Value, key string, table *core.Table) *reflect.Value {
var col *core.Column
if col = table.GetColumn(key); col == nil {
session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.Columns()))
return nil
}
key = strings.ToLower(key) fieldValue, err := col.ValueOfV(dataStruct)
if _, ok := table.Columns[key]; !ok { if err != nil {
session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.ColumnsSeq)) session.Engine.LogError(err)
return nil return nil
} }
col := table.Columns[key]
fieldName := col.FieldName
fieldPath := strings.Split(fieldName, ".")
var fieldValue reflect.Value
if len(fieldPath) > 2 {
session.Engine.LogError("Unsupported mutliderive", fieldName)
return nil
} else if len(fieldPath) == 2 {
parentField := dataStruct.FieldByName(fieldPath[0])
if parentField.IsValid() {
fieldValue = parentField.FieldByName(fieldPath[1])
}
} else {
fieldValue = dataStruct.FieldByName(fieldName)
}
if !fieldValue.IsValid() || !fieldValue.CanSet() { if !fieldValue.IsValid() || !fieldValue.CanSet() {
session.Engine.LogWarn("table %v's column %v is not valid or cannot set", session.Engine.LogWarn("table %v's column %v is not valid or cannot set",
table.Name, key) table.Name, key)
return nil return nil
} }
return &fieldValue return fieldValue
} }
func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error { func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount int, bean interface{}) error {
@ -1395,7 +1401,6 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
for ii, key := range fields { for ii, key := range fields {
if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore //if row is null then ignore
@ -1404,7 +1409,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
continue continue
} }
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
if data, err := value2Bytes(&rawValue); err == nil { if data, err := value2Bytes(&rawValue); err == nil {
structConvert.FromDB(data) structConvert.FromDB(data)
} else { } else {
@ -1634,7 +1639,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
if !hasAssigned { if !hasAssigned {
data, err := value2Bytes(&rawValue) data, err := value2Bytes(&rawValue)
if err == nil { if err == nil {
session.bytes2Value(table.Columns[key], fieldValue, data) session.bytes2Value(table.GetColumn(key), fieldValue, data)
} else { } else {
session.Engine.LogError(err.Error()) session.Engine.LogError(err.Error())
} }
@ -1647,7 +1652,7 @@ func (session *Session) row2Bean(rows *sql.Rows, fields []string, fieldsCount in
func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) { func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) {
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
*sqlStr = filter.Do(*sqlStr, session) *sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
session.Engine.LogSQL(*sqlStr) session.Engine.LogSQL(*sqlStr)
@ -1763,7 +1768,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
colNames := make([]string, 0) colNames := make([]string, 0)
colMultiPlaces := make([]string, 0) colMultiPlaces := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
cols := make([]*Column, 0) cols := make([]*core.Column, 0)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
elemValue := sliceValue.Index(i).Interface() elemValue := sliceValue.Index(i).Interface()
@ -1781,12 +1786,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
// -- // --
if i == 0 { if i == 0 {
for _, col := range table.Columns { for _, col := range table.Columns() {
fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName)
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
if col.MapType == ONLYFROMDB { if col.MapType == core.ONLYFROMDB {
continue continue
} }
if session.Statement.ColumnStr != "" { if session.Statement.ColumnStr != "" {
@ -1814,7 +1819,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
if col.IsAutoIncrement && fieldValue.Int() == 0 { if col.IsAutoIncrement && fieldValue.Int() == 0 {
continue continue
} }
if col.MapType == ONLYFROMDB { if col.MapType == core.ONLYFROMDB {
continue continue
} }
if session.Statement.ColumnStr != "" { if session.Statement.ColumnStr != "" {
@ -1853,7 +1858,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error
return 0, err return 0, err
} }
if table.Cacher != nil && session.Statement.UseCache { if cacher := session.Engine.getCacher(table.Type); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.Statement.TableName())
} }
@ -1904,7 +1909,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
return session.innerInsertMulti(rowsSlicePtr) return session.innerInsertMulti(rowsSlicePtr)
} }
func (session *Session) byte2Time(col *Column, data []byte) (outTime time.Time, outErr error) { func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.Time, outErr error) {
sdata := strings.TrimSpace(string(data)) sdata := strings.TrimSpace(string(data))
var x time.Time var x time.Time
var err error var err error
@ -1929,7 +1934,7 @@ func (session *Session) byte2Time(col *Column, data []byte) (outTime time.Time,
x, err = time.Parse("2006-01-02 15:04:05", sdata) x, err = time.Parse("2006-01-02 15:04:05", sdata)
} else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' {
x, err = time.Parse("2006-01-02", sdata) x, err = time.Parse("2006-01-02", sdata)
} else if col.SQLType.Name == Time { } else if col.SQLType.Name == core.Time {
if strings.Contains(sdata, " ") { if strings.Contains(sdata, " ") {
ssd := strings.Split(sdata, " ") ssd := strings.Split(sdata, " ")
sdata = ssd[1] sdata = ssd[1]
@ -1937,7 +1942,7 @@ func (session *Session) byte2Time(col *Column, data []byte) (outTime time.Time,
sdata = strings.TrimSpace(sdata) sdata = strings.TrimSpace(sdata)
//fmt.Println(sdata) //fmt.Println(sdata)
if session.Engine.dialect.DBType() == MYSQL && len(sdata) > 8 { if session.Engine.dialect.DBType() == core.MYSQL && len(sdata) > 8 {
sdata = sdata[len(sdata)-8:] sdata = sdata[len(sdata)-8:]
} }
//fmt.Println(sdata) //fmt.Println(sdata)
@ -1957,8 +1962,8 @@ func (session *Session) byte2Time(col *Column, data []byte) (outTime time.Time,
} }
// convert a db data([]byte) to a field value // convert a db data([]byte) to a field value
func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data []byte) error { func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, data []byte) error {
if structConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
return structConvert.FromDB(data) return structConvert.FromDB(data)
} }
@ -2018,8 +2023,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var x int64 var x int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == Bit && if col.SQLType.Name == core.Bit &&
session.Engine.dialect.DBType() == MYSQL { session.Engine.dialect.DBType() == core.MYSQL {
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
} else { } else {
@ -2199,7 +2204,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var x int64 var x int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName, "mysql") { strings.Contains(session.Engine.DriverName, "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int64(data[0]) x = int64(data[0])
@ -2225,7 +2230,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName, "mysql") { strings.Contains(session.Engine.DriverName, "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int(data[0]) x = int(data[0])
@ -2254,7 +2259,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName, "mysql") { strings.Contains(session.Engine.DriverName, "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int32(data[0]) x = int32(data[0])
@ -2283,7 +2288,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName, "mysql") { strings.Contains(session.Engine.DriverName, "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int8(data[0]) x = int8(data[0])
@ -2312,7 +2317,7 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var x1 int64 var x1 int64
var err error var err error
// for mysql, when use bit, it returned \x01 // for mysql, when use bit, it returned \x01
if col.SQLType.Name == Bit && if col.SQLType.Name == core.Bit &&
strings.Contains(session.Engine.DriverName, "mysql") { strings.Contains(session.Engine.DriverName, "mysql") {
if len(data) == 1 { if len(data) == 1 {
x = int16(data[0]) x = int16(data[0])
@ -2345,9 +2350,9 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
// convert a field value of a struct to interface for put into db // convert a field value of a struct to interface for put into db
func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (interface{}, error) { func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Value) (interface{}, error) {
if fieldValue.CanAddr() { if fieldValue.CanAddr() {
if fieldConvert, ok := fieldValue.Addr().Interface().(Conversion); ok { if fieldConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
data, err := fieldConvert.ToDB() data, err := fieldConvert.ToDB()
if err != nil { if err != nil {
return 0, err return 0, err
@ -2384,19 +2389,19 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) (
case reflect.Struct: case reflect.Struct:
if fieldType == reflect.TypeOf(c_TIME_DEFAULT) { if fieldType == reflect.TypeOf(c_TIME_DEFAULT) {
t := fieldValue.Interface().(time.Time) t := fieldValue.Interface().(time.Time)
if session.Engine.dialect.DBType() == MSSQL { if session.Engine.dialect.DBType() == core.MSSQL {
if t.IsZero() { if t.IsZero() {
return nil, nil return nil, nil
} }
} }
if col.SQLType.Name == Time { if col.SQLType.Name == core.Time {
//s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700") //s := fieldValue.Interface().(time.Time).Format("2006-01-02 15:04:05 -0700")
s := fieldValue.Interface().(time.Time).Format(time.RFC3339) s := fieldValue.Interface().(time.Time).Format(time.RFC3339)
return s[11:19], nil return s[11:19], nil
} else if col.SQLType.Name == Date { } else if col.SQLType.Name == core.Date {
return fieldValue.Interface().(time.Time).Format("2006-01-02"), nil return fieldValue.Interface().(time.Time).Format("2006-01-02"), nil
} else if col.SQLType.Name == TimeStampz { } else if col.SQLType.Name == core.TimeStampz {
if session.Engine.dialect.DBType() == MSSQL { if session.Engine.dialect.DBType() == core.MSSQL {
tf := t.Format("2006-01-02T15:04:05.9999999Z07:00") tf := t.Format("2006-01-02T15:04:05.9999999Z07:00")
return tf, nil return tf, nil
} }
@ -2470,7 +2475,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
} }
// -- // --
colNames, args, err := table.genCols(session, bean, false, false) colNames, args, err := genCols(table, session, bean, false, false)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -2519,7 +2524,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
// for postgres, many of them didn't implement lastInsertId, so we should // for postgres, many of them didn't implement lastInsertId, so we should
// implemented it ourself. // implemented it ourself.
if session.Engine.DriverName != POSTGRES || table.AutoIncrement == "" { if session.Engine.DriverName != core.POSTGRES || table.AutoIncrement == "" {
res, err := session.exec(sqlStr, args...) res, err := session.exec(sqlStr, args...)
if err != nil { if err != nil {
return 0, err return 0, err
@ -2527,13 +2532,15 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
handleAfterInsertProcessorFunc(bean) handleAfterInsertProcessorFunc(bean)
} }
if table.Cacher != nil && session.Statement.UseCache { if cacher := session.Engine.getCacher(table.Type); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.Statement.TableName())
} }
if table.Version != "" && session.Statement.checkVersion { if table.Version != "" && session.Statement.checkVersion {
verValue := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if verValue.IsValid() && verValue.CanSet() { if err != nil {
session.Engine.LogError(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) verValue.SetInt(1)
} }
} }
@ -2548,8 +2555,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return res.RowsAffected() return res.RowsAffected()
} }
aiValue := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { if err != nil {
session.Engine.LogError(err)
}
if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() {
return res.RowsAffected() return res.RowsAffected()
} }
@ -2580,13 +2591,15 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
handleAfterInsertProcessorFunc(bean) handleAfterInsertProcessorFunc(bean)
} }
if table.Cacher != nil && session.Statement.UseCache { if cacher := session.Engine.getCacher(table.Type); cacher != nil && session.Statement.UseCache {
session.cacheInsert(session.Statement.TableName()) session.cacheInsert(session.Statement.TableName())
} }
if table.Version != "" && session.Statement.checkVersion { if table.Version != "" && session.Statement.checkVersion {
verValue := table.VersionColumn().ValueOf(bean) verValue, err := table.VersionColumn().ValueOf(bean)
if verValue.IsValid() && verValue.CanSet() { if err != nil {
session.Engine.LogError(err)
} else if verValue.IsValid() && verValue.CanSet() {
verValue.SetInt(1) verValue.SetInt(1)
} }
} }
@ -2601,8 +2614,12 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return 1, err return 1, err
} }
aiValue := table.AutoIncrColumn().ValueOf(bean) aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() { if err != nil {
session.Engine.LogError(err)
}
if aiValue == nil || !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() {
return 1, nil return 1, nil
} }
@ -2659,9 +2676,9 @@ func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) {
//TODO: for postgres only, if any other database? //TODO: for postgres only, if any other database?
var paraStr string var paraStr string
if statement.Engine.dialect.DBType() == POSTGRES { if statement.Engine.dialect.DBType() == core.POSTGRES {
paraStr = "$" paraStr = "$"
} else if statement.Engine.dialect.DBType() == MSSQL { } else if statement.Engine.dialect.DBType() == core.MSSQL {
paraStr = ":" paraStr = ":"
} }
@ -2687,7 +2704,7 @@ func (session *Session) cacheInsert(tables ...string) error {
} }
table := session.Statement.RefTable table := session.Statement.RefTable
cacher := table.Cacher cacher := session.Engine.getCacher(table.Type)
for _, t := range tables { for _, t := range tables {
session.Engine.LogDebug("cache clear:", t) session.Engine.LogDebug("cache clear:", t)
@ -2707,7 +2724,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
return ErrCacheFailed return ErrCacheFailed
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
newsql = filter.Do(newsql, session) newsql = filter.Do(newsql, session.Engine.dialect, session.Statement.RefTable)
} }
session.Engine.LogDebug("[xorm:cacheUpdate] new sql", oldhead, newsql) session.Engine.LogDebug("[xorm:cacheUpdate] new sql", oldhead, newsql)
@ -2721,7 +2738,7 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
} }
} }
table := session.Statement.RefTable table := session.Statement.RefTable
cacher := table.Cacher cacher := session.Engine.getCacher(table.Type)
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:]) session.Engine.LogDebug("[xorm:cacheUpdate] get cache sql", newsql, args[nStart:])
ids, err := getCacheSql(cacher, tableName, newsql, args[nStart:]) ids, err := getCacheSql(cacher, tableName, newsql, args[nStart:])
@ -2777,13 +2794,17 @@ func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error {
return ErrCacheFailed return ErrCacheFailed
} }
if col, ok := table.Columns[strings.ToLower(colName)]; ok { if col := table.GetColumn(colName); col != nil {
fieldValue := col.ValueOf(bean) fieldValue, err := col.ValueOf(bean)
session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface()) if err != nil {
if col.IsVersion && session.Statement.checkVersion { session.Engine.LogError(err)
fieldValue.SetInt(fieldValue.Int() + 1)
} else { } else {
fieldValue.Set(reflect.ValueOf(args[idx])) session.Engine.LogDebug("[xorm:cacheUpdate] set bean field", bean, colName, fieldValue.Interface())
if col.IsVersion && session.Statement.checkVersion {
fieldValue.SetInt(fieldValue.Int() + 1)
} else {
fieldValue.Set(reflect.ValueOf(args[idx]))
}
} }
} else { } else {
session.Engine.LogError("[xorm:cacheUpdate] ERROR: column %v is not table %v's", session.Engine.LogError("[xorm:cacheUpdate] ERROR: column %v is not table %v's",
@ -2820,7 +2841,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var colNames []string var colNames []string
var args []interface{} var args []interface{}
var table *Table var table *core.Table
// handle before update processors // handle before update processors
for _, closure := range session.beforeClosures { for _, closure := range session.beforeClosures {
@ -2840,7 +2861,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
colNames, args = buildConditions(session.Engine, table, bean, false, false, colNames, args = buildConditions(session.Engine, table, bean, false, false,
false, false, session.Statement.allUseBool, session.Statement.boolColumnMap) false, false, session.Statement.allUseBool, session.Statement.boolColumnMap)
} else { } else {
colNames, args, err = table.genCols(session, bean, true, true) colNames, args, err = genCols(table, session, bean, true, true)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -2917,7 +2938,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1", session.Engine.Quote(table.Version)+" = "+session.Engine.Quote(table.Version)+" + 1",
condition) condition)
condiArgs = append(condiArgs, table.VersionColumn().ValueOf(bean).Interface()) verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
return 0, err
}
condiArgs = append(condiArgs, verValue.Interface())
} else { } else {
if condition != "" { if condition != "" {
condition = "WHERE " + condition condition = "WHERE " + condition
@ -2946,10 +2972,10 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
return 0, err return 0, err
} }
if table.Cacher != nil && session.Statement.UseCache { if cacher := session.Engine.getCacher(t); cacher != nil && session.Statement.UseCache {
//session.cacheUpdate(sqlStr, args...) //session.cacheUpdate(sqlStr, args...)
table.Cacher.ClearIds(session.Statement.TableName()) cacher.ClearIds(session.Statement.TableName())
table.Cacher.ClearBeans(session.Statement.TableName()) cacher.ClearBeans(session.Statement.TableName())
} }
// handle after update processors // handle after update processors
@ -2990,7 +3016,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
} }
for _, filter := range session.Engine.Filters { for _, filter := range session.Engine.Filters {
sqlStr = filter.Do(sqlStr, session) sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable)
} }
newsql := session.Statement.convertIdSql(sqlStr) newsql := session.Statement.convertIdSql(sqlStr)
@ -2998,7 +3024,7 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error {
return ErrCacheFailed return ErrCacheFailed
} }
cacher := session.Statement.RefTable.Cacher cacher := session.Engine.getCacher(session.Statement.RefTable.Type)
tableName := session.Statement.TableName() tableName := session.Statement.TableName()
ids, err := getCacheSql(cacher, tableName, newsql, args) ids, err := getCacheSql(cacher, tableName, newsql, args)
if err != nil { if err != nil {
@ -3090,7 +3116,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
args = append(session.Statement.Params, args...) args = append(session.Statement.Params, args...)
if table.Cacher != nil && session.Statement.UseCache { if cacher := session.Engine.getCacher(session.Statement.RefTable.Type); cacher != nil && session.Statement.UseCache {
session.cacheDelete(sqlStr, args...) session.cacheDelete(sqlStr, args...)
} }

View File

@ -8,11 +8,13 @@ import (
"encoding/json" "encoding/json"
"strings" "strings"
"time" "time"
"github.com/lunny/xorm/core"
) )
// statement save all the sql info for executing SQL // statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *Table RefTable *core.Table
Engine *Engine Engine *Engine
Start int Start int
LimitN int LimitN int
@ -64,7 +66,7 @@ func (statement *Statement) Init() {
statement.RawSQL = "" statement.RawSQL = ""
statement.RawParams = make([]interface{}, 0) statement.RawParams = make([]interface{}, 0)
statement.BeanArgs = make([]interface{}, 0) statement.BeanArgs = make([]interface{}, 0)
statement.UseCache = statement.Engine.UseCache statement.UseCache = true
statement.UseAutoTime = true statement.UseAutoTime = true
statement.IsDistinct = false statement.IsDistinct = false
statement.allUseBool = false statement.allUseBool = false
@ -237,13 +239,13 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
}*/ }*/
// Auto generating conditions according a struct // Auto generating conditions according a struct
func buildConditions(engine *Engine, table *Table, bean interface{}, func buildConditions(engine *Engine, table *core.Table, bean interface{},
includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool,
boolColumnMap map[string]bool) ([]string, []interface{}) { boolColumnMap map[string]bool) ([]string, []interface{}) {
colNames := make([]string, 0) colNames := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns() {
if !includeVersion && col.IsVersion { if !includeVersion && col.IsVersion {
continue continue
} }
@ -255,10 +257,16 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
} }
// //
//fmt.Println(engine.dialect.DBType(), Text) //fmt.Println(engine.dialect.DBType(), Text)
if engine.dialect.DBType() == MSSQL && col.SQLType.Name == Text { if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text {
continue continue
} }
fieldValue := col.ValueOf(bean) fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
engine.LogError(err)
continue
}
fieldValue := *fieldValuePtr
fieldType := reflect.TypeOf(fieldValue.Interface()) fieldType := reflect.TypeOf(fieldValue.Interface())
requiredField := false requiredField := false
@ -323,10 +331,10 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
continue continue
} }
var str string var str string
if col.SQLType.Name == Time { if col.SQLType.Name == core.Time {
s := t.UTC().Format("2006-01-02 15:04:05") s := t.UTC().Format("2006-01-02 15:04:05")
val = s[11:19] val = s[11:19]
} else if col.SQLType.Name == Date { } else if col.SQLType.Name == core.Date {
str = t.Format("2006-01-02") str = t.Format("2006-01-02")
val = str val = str
} else { } else {
@ -510,7 +518,7 @@ func (statement *Statement) Distinct(columns ...string) *Statement {
func (statement *Statement) Cols(columns ...string) *Statement { func (statement *Statement) Cols(columns ...string) *Statement {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
statement.columnMap[nc] = true statement.columnMap[strings.ToLower(nc)] = true
} }
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
return statement return statement
@ -521,7 +529,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement {
if len(columns) > 0 { if len(columns) > 0 {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
statement.boolColumnMap[nc] = true statement.boolColumnMap[strings.ToLower(nc)] = true
} }
} else { } else {
statement.allUseBool = true statement.allUseBool = true
@ -533,7 +541,7 @@ func (statement *Statement) UseBool(columns ...string) *Statement {
func (statement *Statement) Omit(columns ...string) { func (statement *Statement) Omit(columns ...string) {
newColumns := col2NewCols(columns...) newColumns := col2NewCols(columns...)
for _, nc := range newColumns { for _, nc := range newColumns {
statement.columnMap[nc] = false statement.columnMap[strings.ToLower(nc)] = false
} }
statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
@ -584,13 +592,13 @@ func (statement *Statement) Having(conditions string) *Statement {
func (statement *Statement) genColumnStr() string { func (statement *Statement) genColumnStr() string {
table := statement.RefTable table := statement.RefTable
colNames := make([]string, 0) colNames := make([]string, 0)
for _, col := range table.Columns { for _, col := range table.Columns() {
if statement.OmitStr != "" { if statement.OmitStr != "" {
if _, ok := statement.columnMap[col.Name]; ok { if _, ok := statement.columnMap[strings.ToLower(col.Name)]; ok {
continue continue
} }
} }
if col.MapType == ONLYTODB { if col.MapType == core.ONLYTODB {
continue continue
} }
colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name)) colNames = append(colNames, statement.Engine.Quote(statement.TableName())+"."+statement.Engine.Quote(col.Name))
@ -599,54 +607,8 @@ func (statement *Statement) genColumnStr() string {
} }
func (statement *Statement) genCreateTableSQL() string { func (statement *Statement) genCreateTableSQL() string {
var sql string return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.AltTableName,
if statement.Engine.dialect.DBType() == MSSQL { statement.StoreEngine, statement.Charset)
sql = "IF NOT EXISTS (SELECT [name] FROM sys.tables WHERE [name] = '" + statement.TableName() + "' ) CREATE TABLE"
} else {
sql = "CREATE TABLE IF NOT EXISTS "
}
sql += statement.Engine.Quote(statement.TableName()) + " ("
pkList := []string{}
for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[strings.ToLower(colName)]
if col.IsPrimaryKey {
pkList = append(pkList, col.Name)
}
}
statement.Engine.LogDebug("len:", len(pkList))
for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[strings.ToLower(colName)]
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(statement.Engine.dialect)
} else {
sql += col.stringNoPk(statement.Engine.dialect)
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += strings.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" {
sql += " ENGINE=" + statement.StoreEngine
}
if statement.Engine.dialect.SupportCharset() {
if statement.Charset != "" {
sql += " DEFAULT CHARSET " + statement.Charset
} else if statement.Engine.dialect.URI().charset != "" {
sql += " DEFAULT CHARSET " + statement.Engine.dialect.URI().charset
}
}
sql += ";"
return sql
} }
func indexName(tableName, idxName string) string { func indexName(tableName, idxName string) string {
@ -658,7 +620,7 @@ func (s *Statement) genIndexSQL() []string {
tbName := s.TableName() tbName := s.TableName()
quote := s.Engine.Quote quote := s.Engine.Quote
for idxName, index := range s.RefTable.Indexes { for idxName, index := range s.RefTable.Indexes {
if index.Type == IndexType { if index.Type == core.IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(index.Cols, quote(",")))) quote(tbName), quote(strings.Join(index.Cols, quote(","))))
sqls = append(sqls, sql) sqls = append(sqls, sql)
@ -676,7 +638,7 @@ func (s *Statement) genUniqueSQL() []string {
tbName := s.TableName() tbName := s.TableName()
quote := s.Engine.Quote quote := s.Engine.Quote
for idxName, unique := range s.RefTable.Indexes { for idxName, unique := range s.RefTable.Indexes {
if unique.Type == UniqueType { if unique.Type == core.UniqueType {
sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)), sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)),
quote(tbName), quote(strings.Join(unique.Cols, quote(",")))) quote(tbName), quote(strings.Join(unique.Cols, quote(","))))
sqls = append(sqls, sql) sqls = append(sqls, sql)
@ -689,9 +651,9 @@ func (s *Statement) genDelIndexSQL() []string {
var sqls []string = make([]string, 0) var sqls []string = make([]string, 0)
for idxName, index := range s.RefTable.Indexes { for idxName, index := range s.RefTable.Indexes {
var rIdxName string var rIdxName string
if index.Type == UniqueType { if index.Type == core.UniqueType {
rIdxName = uniqueName(s.TableName(), idxName) rIdxName = uniqueName(s.TableName(), idxName)
} else if index.Type == IndexType { } else if index.Type == core.IndexType {
rIdxName = indexName(s.TableName(), idxName) rIdxName = indexName(s.TableName(), idxName)
} }
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName)) sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName))
@ -704,7 +666,7 @@ func (s *Statement) genDelIndexSQL() []string {
} }
func (s *Statement) genDropSQL() string { func (s *Statement) genDropSQL() string {
if s.Engine.dialect.DBType() == MSSQL { if s.Engine.dialect.DBType() == core.MSSQL {
return "IF EXISTS (SELECT * FROM sysobjects WHERE id = object_id(N'" + return "IF EXISTS (SELECT * FROM sysobjects WHERE id = object_id(N'" +
s.TableName() + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1) " + s.TableName() + "') and OBJECTPROPERTY(id, N'IsUserTable') = 1) " +
"DROP TABLE " + s.Engine.Quote(s.TableName()) + ";" "DROP TABLE " + s.Engine.Quote(s.TableName()) + ";"
@ -731,7 +693,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{})
return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...) return statement.genSelectSql(columnStr), append(statement.Params, statement.BeanArgs...)
} }
func (s *Statement) genAddColumnStr(col *Column) (string, []interface{}) { func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
quote := s.Engine.Quote quote := s.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()), sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()),
col.String(s.Engine.dialect)) col.String(s.Engine.dialect))
@ -804,13 +766,14 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" { if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
} }
if statement.Engine.dialect.DBType() != MSSQL { if statement.Engine.dialect.DBType() != core.MSSQL {
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
} else if statement.LimitN > 0 { } else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
} }
} else { } else {
//TODO: for mssql, should handler limit.
/*SELECT * FROM ( /*SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (ORDER BY id desc) as row FROM "userinfo" SELECT *, ROW_NUMBER() OVER (ORDER BY id desc) as row FROM "userinfo"
) a WHERE row > [start] and row <= [start+limit] order by id desc*/ ) a WHERE row > [start] and row <= [start+limit] order by id desc*/
@ -823,11 +786,12 @@ func (statement *Statement) processIdParam() {
if statement.IdParam != nil { if statement.IdParam != nil {
i := 0 i := 0
colCnt := len(statement.RefTable.ColumnsSeq) columns := statement.RefTable.ColumnsSeq()
colCnt := len(columns)
for _, elem := range *(statement.IdParam) { for _, elem := range *(statement.IdParam) {
for ; i < colCnt; i++ { for ; i < colCnt; i++ {
colName := statement.RefTable.ColumnsSeq[i] colName := columns[i]
col := statement.RefTable.Columns[strings.ToLower(colName)] col := statement.RefTable.GetColumn(colName)
if col.IsPrimaryKey { if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), elem) statement.And(fmt.Sprintf("%v=?", col.Name), elem)
i++ i++
@ -840,8 +804,8 @@ func (statement *Statement) processIdParam() {
// as empty string for now, so this will result sql exec failed instead of unexpected // as empty string for now, so this will result sql exec failed instead of unexpected
// false update/delete // false update/delete
for ; i < colCnt; i++ { for ; i < colCnt; i++ {
colName := statement.RefTable.ColumnsSeq[i] colName := columns[i]
col := statement.RefTable.Columns[strings.ToLower(colName)] col := statement.RefTable.GetColumn(colName)
if col.IsPrimaryKey { if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), "") statement.And(fmt.Sprintf("%v=?", col.Name), "")
} }

409
table.go
View File

@ -2,120 +2,12 @@ package xorm
import ( import (
"reflect" "reflect"
"sort"
"strings" "strings"
"time" "time"
"github.com/lunny/xorm/core"
) )
// xorm SQL types
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
func (s *SQLType) IsText() bool {
return s.Name == Char || s.Name == Varchar || s.Name == TinyText ||
s.Name == Text || s.Name == MediumText || s.Name == LongText
}
func (s *SQLType) IsBlob() bool {
return (s.Name == TinyBlob) || (s.Name == Blob) ||
s.Name == MediumBlob || s.Name == LongBlob ||
s.Name == Binary || s.Name == VarBinary || s.Name == Bytea
}
const ()
var (
Bit = "BIT"
TinyInt = "TINYINT"
SmallInt = "SMALLINT"
MediumInt = "MEDIUMINT"
Int = "INT"
Integer = "INTEGER"
BigInt = "BIGINT"
Char = "CHAR"
Varchar = "VARCHAR"
TinyText = "TINYTEXT"
Text = "TEXT"
MediumText = "MEDIUMTEXT"
LongText = "LONGTEXT"
Date = "DATE"
DateTime = "DATETIME"
Time = "TIME"
TimeStamp = "TIMESTAMP"
TimeStampz = "TIMESTAMPZ"
Decimal = "DECIMAL"
Numeric = "NUMERIC"
Real = "REAL"
Float = "FLOAT"
Double = "DOUBLE"
Binary = "BINARY"
VarBinary = "VARBINARY"
TinyBlob = "TINYBLOB"
Blob = "BLOB"
MediumBlob = "MEDIUMBLOB"
LongBlob = "LONGBLOB"
Bytea = "BYTEA"
Bool = "BOOL"
Serial = "SERIAL"
BigSerial = "BIGSERIAL"
sqlTypes = map[string]bool{
Bit: true,
TinyInt: true,
SmallInt: true,
MediumInt: true,
Int: true,
Integer: true,
BigInt: true,
Char: true,
Varchar: true,
TinyText: true,
Text: true,
MediumText: true,
LongText: true,
Date: true,
DateTime: true,
Time: true,
TimeStamp: true,
TimeStampz: true,
Decimal: true,
Numeric: true,
Binary: true,
VarBinary: true,
Real: true,
Float: true,
Double: true,
TinyBlob: true,
Blob: true,
MediumBlob: true,
LongBlob: true,
Bytea: true,
Bool: true,
Serial: true,
BigSerial: true,
}
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
)
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparision
var ( var (
c_EMPTY_STRING string c_EMPTY_STRING string
c_BOOL_DEFAULT bool c_BOOL_DEFAULT bool
@ -137,290 +29,28 @@ var (
c_TIME_DEFAULT time.Time c_TIME_DEFAULT time.Time
) )
func Type2SQLType(t reflect.Type) (st SQLType) { func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
st = SQLType{Int, 0, 0}
case reflect.Int64, reflect.Uint64:
st = SQLType{BigInt, 0, 0}
case reflect.Float32:
st = SQLType{Float, 0, 0}
case reflect.Float64:
st = SQLType{Double, 0, 0}
case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) {
st = SQLType{Blob, 0, 0}
} else {
st = SQLType{Text, 0, 0}
}
case reflect.Bool:
st = SQLType{Bool, 0, 0}
case reflect.String:
st = SQLType{Varchar, 255, 0}
case reflect.Struct:
if t == reflect.TypeOf(c_TIME_DEFAULT) {
st = SQLType{DateTime, 0, 0}
} else {
// TODO need to handle association struct
st = SQLType{Text, 0, 0}
}
case reflect.Ptr:
st, _ = ptrType2SQLType(t)
default:
st = SQLType{Text, 0, 0}
}
return
}
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
has = true
switch t {
case reflect.TypeOf(&c_EMPTY_STRING):
st = SQLType{Varchar, 255, 0}
return
case reflect.TypeOf(&c_BOOL_DEFAULT):
st = SQLType{Bool, 0, 0}
case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
st = SQLType{Varchar, 64, 0}
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
st = SQLType{Float, 0, 0}
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
st = SQLType{Double, 0, 0}
case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
st = SQLType{BigInt, 0, 0}
case reflect.TypeOf(&c_TIME_DEFAULT):
st = SQLType{DateTime, 0, 0}
case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
st = SQLType{Int, 0, 0}
default:
has = false
}
return
}
// default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name)
switch name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1)
case BigInt, BigSerial:
return reflect.TypeOf(int64(1))
case Float, Real:
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, Varchar, TinyText, Text, MediumText, LongText:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{})
case Bool:
return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz:
return reflect.TypeOf(c_TIME_DEFAULT)
case Decimal, Numeric:
return reflect.TypeOf("")
default:
return reflect.TypeOf("")
}
}
const (
IndexType = iota + 1
UniqueType
)
// database index
type Index struct {
Name string
Type int
Cols []string
}
// add columns which will be composite index
func (index *Index) AddColumn(cols ...string) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
// new an index
func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)}
}
const (
TWOSIDES = iota + 1
ONLYTODB
ONLYFROMDB
)
// database column
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Length2 int
Nullable bool
Default string
Indexes map[string]bool
IsPrimaryKey bool
IsAutoIncrement bool
MapType int
IsCreated bool
IsUpdated bool
IsCascade bool
IsVersion bool
}
// generate column description string according dialect
func (col *Column) String(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
}
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
func (col *Column) stringNoPk(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " "
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
// return col's filed of struct's value
func (col *Column) ValueOf(bean interface{}) reflect.Value {
var fieldValue reflect.Value
if strings.Contains(col.FieldName, ".") {
fields := strings.Split(col.FieldName, ".")
if len(fields) > 2 {
return reflect.ValueOf(nil)
}
fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(fields[0])
fieldValue = fieldValue.FieldByName(fields[1])
} else {
fieldValue = reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
}
return fieldValue
}
// database table
type Table struct {
Name string
Type reflect.Type
ColumnsSeq []string
Columns map[string]*Column
Indexes map[string]*Index
PrimaryKeys []string
AutoIncrement string
Created map[string]bool
Updated string
Version string
Cacher Cacher
}
/*
func NewTable(name string, t reflect.Type) *Table {
return &Table{Name: name, Type: t,
ColumnsSeq: make([]string, 0),
Columns: make(map[string]*Column),
Indexes: make(map[string]*Index),
Created: make(map[string]bool),
}
}*/
// if has primary key, return column
func (table *Table) PKColumns() []*Column {
columns := make([]*Column, 0)
for _, name := range table.PrimaryKeys {
columns = append(columns, table.Columns[strings.ToLower(name)])
}
return columns
}
func (table *Table) AutoIncrColumn() *Column {
return table.Columns[strings.ToLower(table.AutoIncrement)]
}
func (table *Table) VersionColumn() *Column {
return table.Columns[strings.ToLower(table.Version)]
}
// add a column to table
func (table *Table) AddColumn(col *Column) {
table.ColumnsSeq = append(table.ColumnsSeq, col.Name)
table.Columns[strings.ToLower(col.Name)] = col
if col.IsPrimaryKey {
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
}
if col.IsAutoIncrement {
table.AutoIncrement = col.Name
}
if col.IsCreated {
table.Created[col.Name] = true
}
if col.IsUpdated {
table.Updated = col.Name
}
if col.IsVersion {
table.Version = col.Name
}
}
// add an index or an unique to table
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
}
func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
colNames := make([]string, 0) colNames := make([]string, 0)
args := make([]interface{}, 0) args := make([]interface{}, 0)
for _, col := range table.Columns { for _, col := range table.Columns() {
lColName := strings.ToLower(col.Name)
if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated { if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
if _, ok := session.Statement.columnMap[col.Name]; !ok { if _, ok := session.Statement.columnMap[lColName]; !ok {
continue continue
} }
} }
if col.MapType == ONLYFROMDB { if col.MapType == core.ONLYFROMDB {
continue continue
} }
fieldValue := col.ValueOf(bean) fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
session.Engine.LogError(err)
continue
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement { if col.IsAutoIncrement {
switch fieldValue.Type().Kind() { switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
@ -439,12 +69,12 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc
} }
if session.Statement.ColumnStr != "" { if session.Statement.ColumnStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; !ok { if _, ok := session.Statement.columnMap[lColName]; !ok {
continue continue
} }
} }
if session.Statement.OmitStr != "" { if session.Statement.OmitStr != "" {
if _, ok := session.Statement.columnMap[col.Name]; ok { if _, ok := session.Statement.columnMap[lColName]; ok {
continue continue
} }
} }
@ -469,10 +99,3 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc
} }
return colNames, args, nil return colNames, args, nil
} }
// Conversion is an interface. A type implements Conversion will according
// the custom method to fill into database and retrieve from database.
type Conversion interface {
FromDB([]byte) error
ToDB() ([]byte, error)
}

View File

@ -1,4 +1,4 @@
--DROP DATABASE xorm_test; --DROP DATABASE xorm_test;
--DROP DATABASE xorm_test2; --DROP DATABASE xorm_test2;
CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET utf8 COLLATE utf8_general_ci;
CREATE DATABASE IF NOT EXISTS xorm_test2 CHARACTER SET utf8 COLLATE utf8_general_ci; CREATE DATABASE IF NOT EXISTS xorm_test2 CHARACTER SET utf8 COLLATE utf8_general_ci;

51
xorm.go
View File

@ -7,6 +7,10 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"sync" "sync"
"github.com/lunny/xorm/core"
_ "github.com/lunny/xorm/dialects"
_ "github.com/lunny/xorm/drivers"
) )
const ( const (
@ -20,39 +24,38 @@ func close(engine *Engine) {
// new a db manager according to the parameter. Currently support four // new a db manager according to the parameter. Currently support four
// drivers // drivers
func NewEngine(driverName string, dataSourceName string) (*Engine, error) { func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine := &Engine{DriverName: driverName, driver := core.QueryDriver(driverName)
DataSourceName: dataSourceName, Filters: make([]Filter, 0)} if driver == nil {
engine.SetMapper(SnakeMapper{})
if driverName == SQLITE {
engine.dialect = &sqlite3{}
} else if driverName == MYSQL {
engine.dialect = &mysql{}
} else if driverName == POSTGRES {
engine.dialect = &postgres{}
engine.Filters = append(engine.Filters, &PgSeqFilter{})
engine.Filters = append(engine.Filters, &QuoteFilter{})
} else if driverName == MYMYSQL {
engine.dialect = &mymysql{}
} else if driverName == "odbc" {
engine.dialect = &mssql{quoteFilter: &QuoteFilter{}}
engine.Filters = append(engine.Filters, &QuoteFilter{})
} else if driverName == ORACLE_OCI {
engine.dialect = &oracle{}
engine.Filters = append(engine.Filters, &QuoteFilter{})
} else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName)) return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
} }
err := engine.dialect.Init(driverName, dataSourceName)
uri, err := driver.Parse(driverName, dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
engine.Tables = make(map[reflect.Type]*Table) dialect := core.QueryDialect(uri.DbType)
if dialect == nil {
return nil, errors.New(fmt.Sprintf("Unsupported dialect type: %v", uri.DbType))
}
err = dialect.Init(uri, driverName, dataSourceName)
if err != nil {
return nil, err
}
engine := &Engine{DriverName: driverName,
DataSourceName: dataSourceName, dialect: dialect,
tableCachers: make(map[reflect.Type]Cacher)}
engine.SetMapper(SnakeMapper{})
engine.Filters = dialect.Filters()
engine.Tables = make(map[reflect.Type]*core.Table)
engine.mutex = &sync.Mutex{} engine.mutex = &sync.Mutex{}
engine.TagIdentifier = "xorm" engine.TagIdentifier = "xorm"
engine.Filters = append(engine.Filters, &IdFilter{})
engine.Logger = os.Stdout engine.Logger = os.Stdout
//engine.Pool = NewSimpleConnectPool() //engine.Pool = NewSimpleConnectPool()

View File

@ -1,2 +1,2 @@
[deps] [deps]
github.com/lunny/xorm=../ github.com/lunny/xorm=../

View File

@ -1,30 +1,30 @@
# xorm tools # xorm tools
xorm tools is a set of tools for database operation. xorm tools is a set of tools for database operation.
## Install ## Install
`go get github.com/lunny/xorm/xorm` `go get github.com/lunny/xorm/xorm`
and you should install the depends below: and you should install the depends below:
* github.com/lunny/xorm * github.com/lunny/xorm
* Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) * Mysql: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) * SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
* Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq) * Postgres: [github.com/bylevel/pq](https://github.com/bylevel/pq)
## Reverse ## Reverse
After you installed the tool, you can type After you installed the tool, you can type
`xorm help reverse` `xorm help reverse`
to get help to get help
@ -50,13 +50,13 @@ Now, xorm tool supports go and c++ two languages and have go, goxorm, c++ three
```` ````
lang=go lang=go
genJson=1 genJson=1
``` ```
lang must be go or c++ now. lang must be go or c++ now.
genJson can be 1 or 0, if 1 then the struct will have json tag. genJson can be 1 or 0, if 1 then the struct will have json tag.
## LICENSE ## LICENSE
BSD License BSD License
[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)

View File

@ -1,65 +1,65 @@
package main package main
import ( import (
//"fmt" //"fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"strings" "strings"
"text/template" "text/template"
) )
var ( var (
CPlusTmpl LangTmpl = LangTmpl{ CPlusTmpl LangTmpl = LangTmpl{
template.FuncMap{"Mapper": mapper.Table2Obj, template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": cPlusTypeStr, "Type": cPlusTypeStr,
"UnTitle": unTitle, "UnTitle": unTitle,
}, },
nil, nil,
genCPlusImports, genCPlusImports,
} }
) )
func cPlusTypeStr(col *xorm.Column) string { func cPlusTypeStr(col *xorm.Column) string {
tp := col.SQLType tp := col.SQLType
name := strings.ToUpper(tp.Name) name := strings.ToUpper(tp.Name)
switch name { switch name {
case xorm.Bit, xorm.TinyInt, xorm.SmallInt, xorm.MediumInt, xorm.Int, xorm.Integer, xorm.Serial: case xorm.Bit, xorm.TinyInt, xorm.SmallInt, xorm.MediumInt, xorm.Int, xorm.Integer, xorm.Serial:
return "int" return "int"
case xorm.BigInt, xorm.BigSerial: case xorm.BigInt, xorm.BigSerial:
return "__int64" return "__int64"
case xorm.Char, xorm.Varchar, xorm.TinyText, xorm.Text, xorm.MediumText, xorm.LongText: case xorm.Char, xorm.Varchar, xorm.TinyText, xorm.Text, xorm.MediumText, xorm.LongText:
return "tstring" return "tstring"
case xorm.Date, xorm.DateTime, xorm.Time, xorm.TimeStamp: case xorm.Date, xorm.DateTime, xorm.Time, xorm.TimeStamp:
return "time_t" return "time_t"
case xorm.Decimal, xorm.Numeric: case xorm.Decimal, xorm.Numeric:
return "tstring" return "tstring"
case xorm.Real, xorm.Float: case xorm.Real, xorm.Float:
return "float" return "float"
case xorm.Double: case xorm.Double:
return "double" return "double"
case xorm.TinyBlob, xorm.Blob, xorm.MediumBlob, xorm.LongBlob, xorm.Bytea: case xorm.TinyBlob, xorm.Blob, xorm.MediumBlob, xorm.LongBlob, xorm.Bytea:
return "tstring" return "tstring"
case xorm.Bool: case xorm.Bool:
return "bool" return "bool"
default: default:
return "tstring" return "tstring"
} }
return "" return ""
} }
func genCPlusImports(tables []*xorm.Table) map[string]string { func genCPlusImports(tables []*xorm.Table) map[string]string {
imports := make(map[string]string) imports := make(map[string]string)
for _, table := range tables { for _, table := range tables {
for _, col := range table.Columns { for _, col := range table.Columns {
switch cPlusTypeStr(col) { switch cPlusTypeStr(col) {
case "time_t": case "time_t":
imports[`<time.h>`] = `<time.h>` imports[`<time.h>`] = `<time.h>`
case "tstring": case "tstring":
imports["<string>"] = "<string>" imports["<string>"] = "<string>"
//case "__int64": //case "__int64":
// imports[""] = "" // imports[""] = ""
} }
} }
} }
return imports return imports
} }

View File

@ -1,78 +1,78 @@
package main package main
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
) )
// A Command is an implementation of a go command // A Command is an implementation of a go command
// like go build or go fix. // like go build or go fix.
type Command struct { type Command struct {
// Run runs the command. // Run runs the command.
// The args are the arguments after the command name. // The args are the arguments after the command name.
Run func(cmd *Command, args []string) Run func(cmd *Command, args []string)
// UsageLine is the one-line usage message. // UsageLine is the one-line usage message.
// The first word in the line is taken to be the command name. // The first word in the line is taken to be the command name.
UsageLine string UsageLine string
// Short is the short description shown in the 'go help' output. // Short is the short description shown in the 'go help' output.
Short string Short string
// Long is the long message shown in the 'go help <this-command>' output. // Long is the long message shown in the 'go help <this-command>' output.
Long string Long string
// Flag is a set of flags specific to this command. // Flag is a set of flags specific to this command.
Flags map[string]bool Flags map[string]bool
} }
// Name returns the command's name: the first word in the usage line. // Name returns the command's name: the first word in the usage line.
func (c *Command) Name() string { func (c *Command) Name() string {
name := c.UsageLine name := c.UsageLine
i := strings.Index(name, " ") i := strings.Index(name, " ")
if i >= 0 { if i >= 0 {
name = name[:i] name = name[:i]
} }
return name return name
} }
func (c *Command) Usage() { func (c *Command) Usage() {
fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine) fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine)
fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long)) fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long))
os.Exit(2) os.Exit(2)
} }
// Runnable reports whether the command can be run; otherwise // Runnable reports whether the command can be run; otherwise
// it is a documentation pseudo-command such as importpath. // it is a documentation pseudo-command such as importpath.
func (c *Command) Runnable() bool { func (c *Command) Runnable() bool {
return c.Run != nil return c.Run != nil
} }
// checkFlags checks if the flag exists with correct format. // checkFlags checks if the flag exists with correct format.
func checkFlags(flags map[string]bool, args []string, print func(string)) int { func checkFlags(flags map[string]bool, args []string, print func(string)) int {
num := 0 // Number of valid flags, use to cut out. num := 0 // Number of valid flags, use to cut out.
for i, f := range args { for i, f := range args {
// Check flag prefix '-'. // Check flag prefix '-'.
if !strings.HasPrefix(f, "-") { if !strings.HasPrefix(f, "-") {
// Not a flag, finish check process. // Not a flag, finish check process.
break break
} }
// Check if it a valid flag. // Check if it a valid flag.
if v, ok := flags[f]; ok { if v, ok := flags[f]; ok {
flags[f] = !v flags[f] = !v
if !v { if !v {
print(f) print(f)
} else { } else {
fmt.Println("DISABLE: " + f) fmt.Println("DISABLE: " + f)
} }
} else { } else {
fmt.Printf("[ERRO] Unknown flag: %s.\n", f) fmt.Printf("[ERRO] Unknown flag: %s.\n", f)
return -1 return -1
} }
num = i + 1 num = i + 1
} }
return num return num
} }

View File

@ -1,263 +1,263 @@
package main package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"go/format" "go/format"
"reflect" "reflect"
"strings" "strings"
"text/template" "text/template"
) )
var ( var (
GoLangTmpl LangTmpl = LangTmpl{ GoLangTmpl LangTmpl = LangTmpl{
template.FuncMap{"Mapper": mapper.Table2Obj, template.FuncMap{"Mapper": mapper.Table2Obj,
"Type": typestring, "Type": typestring,
"Tag": tag, "Tag": tag,
"UnTitle": unTitle, "UnTitle": unTitle,
"gt": gt, "gt": gt,
"getCol": getCol, "getCol": getCol,
}, },
formatGo, formatGo,
genGoImports, genGoImports,
} }
) )
var ( var (
errBadComparisonType = errors.New("invalid type for comparison") errBadComparisonType = errors.New("invalid type for comparison")
errBadComparison = errors.New("incompatible types for comparison") errBadComparison = errors.New("incompatible types for comparison")
errNoComparison = errors.New("missing argument for comparison") errNoComparison = errors.New("missing argument for comparison")
) )
type kind int type kind int
const ( const (
invalidKind kind = iota invalidKind kind = iota
boolKind boolKind
complexKind complexKind
intKind intKind
floatKind floatKind
integerKind integerKind
stringKind stringKind
uintKind uintKind
) )
func basicKind(v reflect.Value) (kind, error) { func basicKind(v reflect.Value) (kind, error) {
switch v.Kind() { switch v.Kind() {
case reflect.Bool: case reflect.Bool:
return boolKind, nil return boolKind, nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return intKind, nil return intKind, nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return uintKind, nil return uintKind, nil
case reflect.Float32, reflect.Float64: case reflect.Float32, reflect.Float64:
return floatKind, nil return floatKind, nil
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
return complexKind, nil return complexKind, nil
case reflect.String: case reflect.String:
return stringKind, nil return stringKind, nil
} }
return invalidKind, errBadComparisonType return invalidKind, errBadComparisonType
} }
// eq evaluates the comparison a == b || a == c || ... // eq evaluates the comparison a == b || a == c || ...
func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) { func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1) v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1) k1, err := basicKind(v1)
if err != nil { if err != nil {
return false, err return false, err
} }
if len(arg2) == 0 { if len(arg2) == 0 {
return false, errNoComparison return false, errNoComparison
} }
for _, arg := range arg2 { for _, arg := range arg2 {
v2 := reflect.ValueOf(arg) v2 := reflect.ValueOf(arg)
k2, err := basicKind(v2) k2, err := basicKind(v2)
if err != nil { if err != nil {
return false, err return false, err
} }
if k1 != k2 { if k1 != k2 {
return false, errBadComparison return false, errBadComparison
} }
truth := false truth := false
switch k1 { switch k1 {
case boolKind: case boolKind:
truth = v1.Bool() == v2.Bool() truth = v1.Bool() == v2.Bool()
case complexKind: case complexKind:
truth = v1.Complex() == v2.Complex() truth = v1.Complex() == v2.Complex()
case floatKind: case floatKind:
truth = v1.Float() == v2.Float() truth = v1.Float() == v2.Float()
case intKind: case intKind:
truth = v1.Int() == v2.Int() truth = v1.Int() == v2.Int()
case stringKind: case stringKind:
truth = v1.String() == v2.String() truth = v1.String() == v2.String()
case uintKind: case uintKind:
truth = v1.Uint() == v2.Uint() truth = v1.Uint() == v2.Uint()
default: default:
panic("invalid kind") panic("invalid kind")
} }
if truth { if truth {
return true, nil return true, nil
} }
} }
return false, nil return false, nil
} }
// lt evaluates the comparison a < b. // lt evaluates the comparison a < b.
func lt(arg1, arg2 interface{}) (bool, error) { func lt(arg1, arg2 interface{}) (bool, error) {
v1 := reflect.ValueOf(arg1) v1 := reflect.ValueOf(arg1)
k1, err := basicKind(v1) k1, err := basicKind(v1)
if err != nil { if err != nil {
return false, err return false, err
} }
v2 := reflect.ValueOf(arg2) v2 := reflect.ValueOf(arg2)
k2, err := basicKind(v2) k2, err := basicKind(v2)
if err != nil { if err != nil {
return false, err return false, err
} }
if k1 != k2 { if k1 != k2 {
return false, errBadComparison return false, errBadComparison
} }
truth := false truth := false
switch k1 { switch k1 {
case boolKind, complexKind: case boolKind, complexKind:
return false, errBadComparisonType return false, errBadComparisonType
case floatKind: case floatKind:
truth = v1.Float() < v2.Float() truth = v1.Float() < v2.Float()
case intKind: case intKind:
truth = v1.Int() < v2.Int() truth = v1.Int() < v2.Int()
case stringKind: case stringKind:
truth = v1.String() < v2.String() truth = v1.String() < v2.String()
case uintKind: case uintKind:
truth = v1.Uint() < v2.Uint() truth = v1.Uint() < v2.Uint()
default: default:
panic("invalid kind") panic("invalid kind")
} }
return truth, nil return truth, nil
} }
// le evaluates the comparison <= b. // le evaluates the comparison <= b.
func le(arg1, arg2 interface{}) (bool, error) { func le(arg1, arg2 interface{}) (bool, error) {
// <= is < or ==. // <= is < or ==.
lessThan, err := lt(arg1, arg2) lessThan, err := lt(arg1, arg2)
if lessThan || err != nil { if lessThan || err != nil {
return lessThan, err return lessThan, err
} }
return eq(arg1, arg2) return eq(arg1, arg2)
} }
// gt evaluates the comparison a > b. // gt evaluates the comparison a > b.
func gt(arg1, arg2 interface{}) (bool, error) { func gt(arg1, arg2 interface{}) (bool, error) {
// > is the inverse of <=. // > is the inverse of <=.
lessOrEqual, err := le(arg1, arg2) lessOrEqual, err := le(arg1, arg2)
if err != nil { if err != nil {
return false, err return false, err
} }
return !lessOrEqual, nil return !lessOrEqual, nil
} }
func getCol(cols map[string]*xorm.Column, name string) *xorm.Column { func getCol(cols map[string]*xorm.Column, name string) *xorm.Column {
return cols[name] return cols[name]
} }
func formatGo(src string) (string, error) { func formatGo(src string) (string, error) {
source, err := format.Source([]byte(src)) source, err := format.Source([]byte(src))
if err != nil { if err != nil {
return "", err return "", err
} }
return string(source), nil return string(source), nil
} }
func genGoImports(tables []*xorm.Table) map[string]string { func genGoImports(tables []*xorm.Table) map[string]string {
imports := make(map[string]string) imports := make(map[string]string)
for _, table := range tables { for _, table := range tables {
for _, col := range table.Columns { for _, col := range table.Columns {
if typestring(col) == "time.Time" { if typestring(col) == "time.Time" {
imports["time"] = "time" imports["time"] = "time"
} }
} }
} }
return imports return imports
} }
func typestring(col *xorm.Column) string { func typestring(col *xorm.Column) string {
st := col.SQLType st := col.SQLType
/*if col.IsPrimaryKey { /*if col.IsPrimaryKey {
return "int64" return "int64"
}*/ }*/
t := xorm.SQLType2Type(st) t := xorm.SQLType2Type(st)
s := t.String() s := t.String()
if s == "[]uint8" { if s == "[]uint8" {
return "[]byte" return "[]byte"
} }
return s return s
} }
func tag(table *xorm.Table, col *xorm.Column) string { func tag(table *xorm.Table, col *xorm.Column) string {
isNameId := (mapper.Table2Obj(col.Name) == "Id") isNameId := (mapper.Table2Obj(col.Name) == "Id")
isIdPk := isNameId && typestring(col) == "int64" isIdPk := isNameId && typestring(col) == "int64"
res := make([]string, 0) res := make([]string, 0)
if !col.Nullable { if !col.Nullable {
if !isIdPk { if !isIdPk {
res = append(res, "not null") res = append(res, "not null")
} }
} }
if col.IsPrimaryKey { if col.IsPrimaryKey {
if !isIdPk { if !isIdPk {
res = append(res, "pk") res = append(res, "pk")
} }
} }
if col.Default != "" { if col.Default != "" {
res = append(res, "default "+col.Default) res = append(res, "default "+col.Default)
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
if !isIdPk { if !isIdPk {
res = append(res, "autoincr") res = append(res, "autoincr")
} }
} }
if col.IsCreated { if col.IsCreated {
res = append(res, "created") res = append(res, "created")
} }
if col.IsUpdated { if col.IsUpdated {
res = append(res, "updated") res = append(res, "updated")
} }
for name, _ := range col.Indexes { for name, _ := range col.Indexes {
index := table.Indexes[name] index := table.Indexes[name]
var uistr string var uistr string
if index.Type == xorm.UniqueType { if index.Type == xorm.UniqueType {
uistr = "unique" uistr = "unique"
} else if index.Type == xorm.IndexType { } else if index.Type == xorm.IndexType {
uistr = "index" uistr = "index"
} }
if len(index.Cols) > 1 { if len(index.Cols) > 1 {
uistr += "(" + index.Name + ")" uistr += "(" + index.Name + ")"
} }
res = append(res, uistr) res = append(res, uistr)
} }
nstr := col.SQLType.Name nstr := col.SQLType.Name
if col.Length != 0 { if col.Length != 0 {
if col.Length2 != 0 { if col.Length2 != 0 {
nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2) nstr += fmt.Sprintf("(%v, %v)", col.Length, col.Length2)
} else { } else {
nstr += fmt.Sprintf("(%v)", col.Length) nstr += fmt.Sprintf("(%v)", col.Length)
} }
} }
res = append(res, nstr) res = append(res, nstr)
var tags []string var tags []string
if genJson { if genJson {
tags = append(tags, "json:\""+col.Name+"\"") tags = append(tags, "json:\""+col.Name+"\"")
} }
if len(res) > 0 { if len(res) > 0 {
tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"") tags = append(tags, "xorm:\""+strings.Join(res, " ")+"\"")
} }
if len(tags) > 0 { if len(tags) > 0 {
return "`" + strings.Join(tags, " ") + "`" return "`" + strings.Join(tags, " ") + "`"
} else { } else {
return "" return ""
} }
} }

View File

@ -1,51 +1,51 @@
package main package main
import ( import (
"github.com/lunny/xorm" "github.com/lunny/xorm"
"io/ioutil" "io/ioutil"
"strings" "strings"
"text/template" "text/template"
) )
type LangTmpl struct { type LangTmpl struct {
Funcs template.FuncMap Funcs template.FuncMap
Formater func(string) (string, error) Formater func(string) (string, error)
GenImports func([]*xorm.Table) map[string]string GenImports func([]*xorm.Table) map[string]string
} }
var ( var (
mapper = &xorm.SnakeMapper{} mapper = &xorm.SnakeMapper{}
langTmpls = map[string]LangTmpl{ langTmpls = map[string]LangTmpl{
"go": GoLangTmpl, "go": GoLangTmpl,
"c++": CPlusTmpl, "c++": CPlusTmpl,
} }
) )
func loadConfig(f string) map[string]string { func loadConfig(f string) map[string]string {
bts, err := ioutil.ReadFile(f) bts, err := ioutil.ReadFile(f)
if err != nil { if err != nil {
return nil return nil
} }
configs := make(map[string]string) configs := make(map[string]string)
lines := strings.Split(string(bts), "\n") lines := strings.Split(string(bts), "\n")
for _, line := range lines { for _, line := range lines {
line = strings.TrimRight(line, "\r") line = strings.TrimRight(line, "\r")
vs := strings.Split(line, "=") vs := strings.Split(line, "=")
if len(vs) == 2 { if len(vs) == 2 {
configs[strings.TrimSpace(vs[0])] = strings.TrimSpace(vs[1]) configs[strings.TrimSpace(vs[0])] = strings.TrimSpace(vs[1])
} }
} }
return configs return configs
} }
func unTitle(src string) string { func unTitle(src string) string {
if src == "" { if src == "" {
return "" return ""
} }
if len(src) == 1 { if len(src) == 1 {
return strings.ToLower(string(src[0])) return strings.ToLower(string(src[0]))
} else { } else {
return strings.ToLower(string(src[0])) + src[1:] return strings.ToLower(string(src[0])) + src[1:]
} }
} }

View File

@ -1,26 +1,26 @@
package main package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
_ "github.com/bylevel/pq" _ "github.com/bylevel/pq"
"github.com/dvirsky/go-pylog/logging" "github.com/dvirsky/go-pylog/logging"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/lunny/xorm" "github.com/lunny/xorm"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv" _ "github.com/ziutek/mymysql/godrv"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"strconv" "strconv"
"text/template" "text/template"
) )
var CmdReverse = &Command{ var CmdReverse = &Command{
UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]", UsageLine: "reverse [-m] driverName datasourceName tmplPath [generatedPath]",
Short: "reverse a db to codes", Short: "reverse a db to codes",
Long: ` Long: `
according database's tables and columns to generate codes for Go, C++ and etc. according database's tables and columns to generate codes for Go, C++ and etc.
-m Generated one go file for every table -m Generated one go file for every table
@ -33,236 +33,236 @@ according database's tables and columns to generate codes for Go, C++ and etc.
} }
func init() { func init() {
CmdReverse.Run = runReverse CmdReverse.Run = runReverse
CmdReverse.Flags = map[string]bool{ CmdReverse.Flags = map[string]bool{
"-s": false, "-s": false,
"-l": false, "-l": false,
} }
} }
var ( var (
genJson bool = false genJson bool = false
) )
func printReversePrompt(flag string) { func printReversePrompt(flag string) {
} }
type Tmpl struct { type Tmpl struct {
Tables []*xorm.Table Tables []*xorm.Table
Imports map[string]string Imports map[string]string
Model string Model string
} }
func dirExists(dir string) bool { func dirExists(dir string) bool {
d, e := os.Stat(dir) d, e := os.Stat(dir)
switch { switch {
case e != nil: case e != nil:
return false return false
case !d.IsDir(): case !d.IsDir():
return false return false
} }
return true return true
} }
func runReverse(cmd *Command, args []string) { func runReverse(cmd *Command, args []string) {
num := checkFlags(cmd.Flags, args, printReversePrompt) num := checkFlags(cmd.Flags, args, printReversePrompt)
if num == -1 { if num == -1 {
return return
} }
args = args[num:] args = args[num:]
if len(args) < 3 { if len(args) < 3 {
fmt.Println("params error, please see xorm help reverse") fmt.Println("params error, please see xorm help reverse")
return return
} }
var isMultiFile bool = true var isMultiFile bool = true
if use, ok := cmd.Flags["-s"]; ok { if use, ok := cmd.Flags["-s"]; ok {
isMultiFile = !use isMultiFile = !use
} }
curPath, err := os.Getwd() curPath, err := os.Getwd()
if err != nil { if err != nil {
fmt.Println(curPath) fmt.Println(curPath)
return return
} }
var genDir string var genDir string
var model string var model string
if len(args) == 4 { if len(args) == 4 {
genDir, err = filepath.Abs(args[3]) genDir, err = filepath.Abs(args[3])
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
model = path.Base(genDir) model = path.Base(genDir)
} else { } else {
model = "model" model = "model"
genDir = path.Join(curPath, model) genDir = path.Join(curPath, model)
} }
dir, err := filepath.Abs(args[2]) dir, err := filepath.Abs(args[2])
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
if !dirExists(dir) { if !dirExists(dir) {
logging.Error("Template %v path is not exist", dir) logging.Error("Template %v path is not exist", dir)
return return
} }
var langTmpl LangTmpl var langTmpl LangTmpl
var ok bool var ok bool
var lang string = "go" var lang string = "go"
cfgPath := path.Join(dir, "config") cfgPath := path.Join(dir, "config")
info, err := os.Stat(cfgPath) info, err := os.Stat(cfgPath)
var configs map[string]string var configs map[string]string
if err == nil && !info.IsDir() { if err == nil && !info.IsDir() {
configs = loadConfig(cfgPath) configs = loadConfig(cfgPath)
if l, ok := configs["lang"]; ok { if l, ok := configs["lang"]; ok {
lang = l lang = l
} }
if j, ok := configs["genJson"]; ok { if j, ok := configs["genJson"]; ok {
genJson, err = strconv.ParseBool(j) genJson, err = strconv.ParseBool(j)
} }
} }
if langTmpl, ok = langTmpls[lang]; !ok { if langTmpl, ok = langTmpls[lang]; !ok {
fmt.Println("Unsupported programing language", lang) fmt.Println("Unsupported programing language", lang)
return return
} }
os.MkdirAll(genDir, os.ModePerm) os.MkdirAll(genDir, os.ModePerm)
Orm, err := xorm.NewEngine(args[0], args[1]) Orm, err := xorm.NewEngine(args[0], args[1])
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
tables, err := Orm.DBMetas() tables, err := Orm.DBMetas()
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return return
} }
filepath.Walk(dir, func(f string, info os.FileInfo, err error) error { filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
if info.IsDir() { if info.IsDir() {
return nil return nil
} }
if info.Name() == "config" { if info.Name() == "config" {
return nil return nil
} }
bs, err := ioutil.ReadFile(f) bs, err := ioutil.ReadFile(f)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
t := template.New(f) t := template.New(f)
t.Funcs(langTmpl.Funcs) t.Funcs(langTmpl.Funcs)
tmpl, err := t.Parse(string(bs)) tmpl, err := t.Parse(string(bs))
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var w *os.File var w *os.File
fileName := info.Name() fileName := info.Name()
newFileName := fileName[:len(fileName)-4] newFileName := fileName[:len(fileName)-4]
ext := path.Ext(newFileName) ext := path.Ext(newFileName)
if !isMultiFile { if !isMultiFile {
w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600) w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0600)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
imports := langTmpl.GenImports(tables) imports := langTmpl.GenImports(tables)
tbls := make([]*xorm.Table, 0) tbls := make([]*xorm.Table, 0)
for _, table := range tables { for _, table := range tables {
tbls = append(tbls, table) tbls = append(tbls, table)
} }
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
t := &Tmpl{Tables: tbls, Imports: imports, Model: model} t := &Tmpl{Tables: tbls, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t) err = tmpl.Execute(newbytes, t)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
tplcontent, err := ioutil.ReadAll(newbytes) tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var source string var source string
if langTmpl.Formater != nil { if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent)) source, err = langTmpl.Formater(string(tplcontent))
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
} else { } else {
source = string(tplcontent) source = string(tplcontent)
} }
w.WriteString(source) w.WriteString(source)
w.Close() w.Close()
} else { } else {
for _, table := range tables { for _, table := range tables {
// imports // imports
tbs := []*xorm.Table{table} tbs := []*xorm.Table{table}
imports := langTmpl.GenImports(tbs) imports := langTmpl.GenImports(tbs)
w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600) w, err := os.OpenFile(path.Join(genDir, unTitle(mapper.Table2Obj(table.Name))+ext), os.O_RDWR|os.O_CREATE, 0600)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
newbytes := bytes.NewBufferString("") newbytes := bytes.NewBufferString("")
t := &Tmpl{Tables: tbs, Imports: imports, Model: model} t := &Tmpl{Tables: tbs, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t) err = tmpl.Execute(newbytes, t)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
tplcontent, err := ioutil.ReadAll(newbytes) tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil { if err != nil {
logging.Error("%v", err) logging.Error("%v", err)
return err return err
} }
var source string var source string
if langTmpl.Formater != nil { if langTmpl.Formater != nil {
source, err = langTmpl.Formater(string(tplcontent)) source, err = langTmpl.Formater(string(tplcontent))
if err != nil { if err != nil {
logging.Error("%v-%v", err, string(tplcontent)) logging.Error("%v-%v", err, string(tplcontent))
return err return err
} }
} else { } else {
source = string(tplcontent) source = string(tplcontent)
} }
w.WriteString(source) w.WriteString(source)
w.Close() w.Close()
} }
} }
return nil return nil
}) })
} }

View File

@ -1,15 +1,15 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/lunny/xorm" "github.com/lunny/xorm"
"strings" "strings"
) )
var CmdShell = &Command{ var CmdShell = &Command{
UsageLine: "shell driverName datasourceName", UsageLine: "shell driverName datasourceName",
Short: "a general shell to operate all kinds of database", Short: "a general shell to operate all kinds of database",
Long: ` Long: `
general database's shell for sqlite3, mysql, postgres. general database's shell for sqlite3, mysql, postgres.
driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres driverName Database driver name, now supported four: mysql mymysql sqlite3 postgres
@ -18,14 +18,14 @@ general database's shell for sqlite3, mysql, postgres.
} }
func init() { func init() {
CmdShell.Run = runShell CmdShell.Run = runShell
CmdShell.Flags = map[string]bool{} CmdShell.Flags = map[string]bool{}
} }
var engine *xorm.Engine var engine *xorm.Engine
func shellHelp() { func shellHelp() {
fmt.Println(` fmt.Println(`
show tables show all tables show tables show all tables
columns <table_name> show table's column info columns <table_name> show table's column info
indexes <table_name> show table's index info indexes <table_name> show table's index info
@ -38,110 +38,110 @@ func shellHelp() {
} }
func runShell(cmd *Command, args []string) { func runShell(cmd *Command, args []string) {
if len(args) != 2 { if len(args) != 2 {
fmt.Println("params error, please see xorm help shell") fmt.Println("params error, please see xorm help shell")
return return
} }
var err error var err error
engine, err = xorm.NewEngine(args[0], args[1]) engine, err = xorm.NewEngine(args[0], args[1])
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
err = engine.Ping() err = engine.Ping()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
return return
} }
var scmd string var scmd string
fmt.Print("xorm$ ") fmt.Print("xorm$ ")
for { for {
var input string var input string
_, err := fmt.Scan(&input) _, err := fmt.Scan(&input)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
continue continue
} }
if strings.ToLower(input) == "exit" { if strings.ToLower(input) == "exit" {
fmt.Println("bye") fmt.Println("bye")
return return
} }
if !strings.HasSuffix(input, ";") { if !strings.HasSuffix(input, ";") {
scmd = scmd + " " + input scmd = scmd + " " + input
continue continue
} }
scmd = scmd + " " + input scmd = scmd + " " + input
lcmd := strings.TrimSpace(strings.ToLower(scmd)) lcmd := strings.TrimSpace(strings.ToLower(scmd))
if strings.HasPrefix(lcmd, "select") { if strings.HasPrefix(lcmd, "select") {
res, err := engine.Query(scmd + "\n") res, err := engine.Query(scmd + "\n")
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
if len(res) <= 0 { if len(res) <= 0 {
fmt.Println("no records") fmt.Println("no records")
} else { } else {
columns := make(map[string]int) columns := make(map[string]int)
for k, _ := range res[0] { for k, _ := range res[0] {
columns[k] = len(k) columns[k] = len(k)
} }
for _, m := range res { for _, m := range res {
for k, s := range m { for k, s := range m {
l := len(string(s)) l := len(string(s))
if l > columns[k] { if l > columns[k] {
columns[k] = l columns[k] = l
} }
} }
} }
var maxlen = 0 var maxlen = 0
for _, l := range columns { for _, l := range columns {
maxlen = maxlen + l + 3 maxlen = maxlen + l + 3
} }
maxlen = maxlen + 1 maxlen = maxlen + 1
fmt.Println(strings.Repeat("-", maxlen)) fmt.Println(strings.Repeat("-", maxlen))
fmt.Print("|") fmt.Print("|")
slice := make([]string, 0) slice := make([]string, 0)
for k, l := range columns { for k, l := range columns {
fmt.Print(" " + k + " ") fmt.Print(" " + k + " ")
fmt.Print(strings.Repeat(" ", l-len(k))) fmt.Print(strings.Repeat(" ", l-len(k)))
fmt.Print("|") fmt.Print("|")
slice = append(slice, k) slice = append(slice, k)
} }
fmt.Print("\n") fmt.Print("\n")
for _, r := range res { for _, r := range res {
fmt.Print("|") fmt.Print("|")
for _, k := range slice { for _, k := range slice {
fmt.Print(" " + string(r[k]) + " ") fmt.Print(" " + string(r[k]) + " ")
fmt.Print(strings.Repeat(" ", columns[k]-len(string(r[k])))) fmt.Print(strings.Repeat(" ", columns[k]-len(string(r[k]))))
fmt.Print("|") fmt.Print("|")
} }
fmt.Print("\n") fmt.Print("\n")
} }
fmt.Println(strings.Repeat("-", maxlen)) fmt.Println(strings.Repeat("-", maxlen))
//fmt.Println(res) //fmt.Println(res)
} }
} }
} else if lcmd == "show tables;" { } else if lcmd == "show tables;" {
/*tables, err := engine.DBMetas() /*tables, err := engine.DBMetas()
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
}*/ }*/
} else { } else {
cnt, err := engine.Exec(scmd) cnt, err := engine.Exec(scmd)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} else { } else {
fmt.Printf("%d records changed.\n", cnt) fmt.Printf("%d records changed.\n", cnt)
} }
} }
scmd = "" scmd = ""
fmt.Print("xorm$ ") fmt.Print("xorm$ ")
} }
} }

View File

@ -1,16 +1,16 @@
package main package main
import ( import (
"fmt" "fmt"
"github.com/dvirsky/go-pylog/logging" "github.com/dvirsky/go-pylog/logging"
"io" "io"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"text/template" "text/template"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
) )
// +build go1.1 // +build go1.1
@ -23,52 +23,52 @@ const version = "0.1"
// Commands lists the available commands and help topics. // Commands lists the available commands and help topics.
// The order here is the order in which they are printed by 'gopm help'. // The order here is the order in which they are printed by 'gopm help'.
var commands = []*Command{ var commands = []*Command{
CmdReverse, CmdReverse,
CmdShell, CmdShell,
} }
func init() { func init() {
runtime.GOMAXPROCS(runtime.NumCPU()) runtime.GOMAXPROCS(runtime.NumCPU())
} }
func main() { func main() {
logging.SetLevel(logging.ALL) logging.SetLevel(logging.ALL)
// Check length of arguments. // Check length of arguments.
args := os.Args[1:] args := os.Args[1:]
if len(args) < 1 { if len(args) < 1 {
usage() usage()
return return
} }
// Show help documentation. // Show help documentation.
if args[0] == "help" { if args[0] == "help" {
help(args[1:]) help(args[1:])
return return
} }
// Check commands and run. // Check commands and run.
for _, comm := range commands { for _, comm := range commands {
if comm.Name() == args[0] && comm.Run != nil { if comm.Name() == args[0] && comm.Run != nil {
comm.Run(comm, args[1:]) comm.Run(comm, args[1:])
exit() exit()
return return
} }
} }
fmt.Fprintf(os.Stderr, "xorm: unknown subcommand %q\nRun 'xorm help' for usage.\n", args[0]) fmt.Fprintf(os.Stderr, "xorm: unknown subcommand %q\nRun 'xorm help' for usage.\n", args[0])
setExitStatus(2) setExitStatus(2)
exit() exit()
} }
var exitStatus = 0 var exitStatus = 0
var exitMu sync.Mutex var exitMu sync.Mutex
func setExitStatus(n int) { func setExitStatus(n int) {
exitMu.Lock() exitMu.Lock()
if exitStatus < n { if exitStatus < n {
exitStatus = n exitStatus = n
} }
exitMu.Unlock() exitMu.Unlock()
} }
var usageTemplate = `xorm is a database tool based xorm package. var usageTemplate = `xorm is a database tool based xorm package.
@ -97,66 +97,66 @@ var helpTemplate = `{{if .Runnable}}usage: xorm {{.UsageLine}}
// tmpl executes the given template text on data, writing the result to w. // tmpl executes the given template text on data, writing the result to w.
func tmpl(w io.Writer, text string, data interface{}) { func tmpl(w io.Writer, text string, data interface{}) {
t := template.New("top") t := template.New("top")
t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize}) t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize})
template.Must(t.Parse(text)) template.Must(t.Parse(text))
if err := t.Execute(w, data); err != nil { if err := t.Execute(w, data); err != nil {
panic(err) panic(err)
} }
} }
func capitalize(s string) string { func capitalize(s string) string {
if s == "" { if s == "" {
return s return s
} }
r, n := utf8.DecodeRuneInString(s) r, n := utf8.DecodeRuneInString(s)
return string(unicode.ToTitle(r)) + s[n:] return string(unicode.ToTitle(r)) + s[n:]
} }
func printUsage(w io.Writer) { func printUsage(w io.Writer) {
tmpl(w, usageTemplate, commands) tmpl(w, usageTemplate, commands)
} }
func usage() { func usage() {
printUsage(os.Stderr) printUsage(os.Stderr)
os.Exit(2) os.Exit(2)
} }
// help implements the 'help' command. // help implements the 'help' command.
func help(args []string) { func help(args []string) {
if len(args) == 0 { if len(args) == 0 {
printUsage(os.Stdout) printUsage(os.Stdout)
// not exit 2: succeeded at 'gopm help'. // not exit 2: succeeded at 'gopm help'.
return return
} }
if len(args) != 1 { if len(args) != 1 {
fmt.Fprintf(os.Stderr, "usage: xorm help command\n\nToo many arguments given.\n") fmt.Fprintf(os.Stderr, "usage: xorm help command\n\nToo many arguments given.\n")
os.Exit(2) // failed at 'gopm help' os.Exit(2) // failed at 'gopm help'
} }
arg := args[0] arg := args[0]
for _, cmd := range commands { for _, cmd := range commands {
if cmd.Name() == arg { if cmd.Name() == arg {
tmpl(os.Stdout, helpTemplate, cmd) tmpl(os.Stdout, helpTemplate, cmd)
// not exit 2: succeeded at 'gopm help cmd'. // not exit 2: succeeded at 'gopm help cmd'.
return return
} }
} }
fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'xorm help'.\n", arg) fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'xorm help'.\n", arg)
os.Exit(2) // failed at 'gopm help cmd' os.Exit(2) // failed at 'gopm help cmd'
} }
var atexitFuncs []func() var atexitFuncs []func()
func atexit(f func()) { func atexit(f func()) {
atexitFuncs = append(atexitFuncs, f) atexitFuncs = append(atexitFuncs, f)
} }
func exit() { func exit() {
for _, f := range atexitFuncs { for _, f := range atexitFuncs {
f() f()
} }
os.Exit(exitStatus) os.Exit(exitStatus)
} }