fix mysql test and sqlite test

This commit is contained in:
Lunny Xiao 2013-05-06 16:01:17 +08:00
parent 95927253d4
commit 74e0e3b175
7 changed files with 580 additions and 392 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.o
*.a
*.so
*.db
# Folders
_obj

148
README.md
View File

@ -1,5 +1,5 @@
xorm
=====
# xorm
===========
[中文](README_CN.md)
@ -11,104 +11,105 @@ All in all, it's not entirely ready for advanced use yet, but it's getting there
Drivers for Go's sql package which support database/sql includes:
Mysql:[github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
Mysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
SQLite:[github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
### Installing xorm
go get github.com/lunny/xorm
## Installing xorm
### Quick Start
go get github.com/lunny/xorm
1.Create an database engine (for example: mysql)
## Quick Start
1.Create a database engine (for example: mysql)
engine := xorm.Create("mysql://root:123@localhost/test")
```go
engine := xorm.Create("mysql://root:123@localhost/test")
```
2.Define your struct
```go
type User struct {
Id int
Name string
Age int `xorm:"-"`
}
```
type User struct {
Id int
Name string
Age int `xorm:"-"`
}
for Simple Task, just use engine's functions:
begin start, you should create a database and then we create the tables
before beginning, you should create a database in mysql and then we will create the tables.
err := engine.CreateTables(&User{})
```go
err := engine.CreateTables(&User{})
```
then, insert an struct to table
```go
id, err := engine.Insert(&User{Name:"lunny"})
```
id, err := engine.Insert(&User{Name:"lunny"})
or you want to update this struct
```go
user := User{Id:1, Name:"xlw"}
rows, err := engine.Update(&user)
```
user := User{Id:1, Name:"xlw"}
rows, err := engine.Update(&user)
3.Fetch a single object by user
```go
var user = User{Id:27}
engine.Get(&user)
var user = User{Name:"xlw"}
engine.Get(&user)
```
var user = User{Id:27}
engine.Get(&user)
var user = User{Name:"xlw"}
engine.Get(&user)
##Deep Use
for deep use, you should create a session, this func will create a connection to db
```go
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return
}
```
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return
}
1.Fetch a single object by where
```go
var user Userinfo
session.Where("id=?", 27).Get(&user)
var user2 Userinfo
session.Where(3).Get(&user2) // this is shorthand for the version above
var user Userinfo
session.Where("id=?", 27).Get(&user)
var user3 Userinfo
session.Where("name = ?", "john").Get(&user3) // more complex query
var user2 Userinfo
session.Where(3).Get(&user2) // this is shorthand for the version above
var user3 Userinfo
session.Where("name = ?", "john").Get(&user3) // more complex query
var user4 Userinfo
session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex
var user4 Userinfo
session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex
```
2.Fetch multiple objects
```go
var allusers []Userinfo
err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20
var tenusers []Userinfo
err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0
var allusers []Userinfo
err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20
var everyone []Userinfo
err := session.Find(&everyone)
```
var tenusers []Userinfo
err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0
###***About Map Rules***
var everyone []Userinfo
err := session.Find(&everyone)
##***Mapping Rules***
1.Struct and struct's fields name should be Pascal style, and the table and column's name default is us
for example:
The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname.
@ -117,13 +118,22 @@ If the keyname is 'UserName' will turn into the select colum 'user_name'
2.You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper
another is use field tag, field tag support the below keywords:
[name] column name
pk the field is a primary key
int(11)/varchar(50) column type
autoincr auto incrment
[not ]null if column can be null value
unique unique
- this field is not map as a table column
* [name] column name
* pk the field is a primary key
* int(11)/varchar(50) column type
* autoincr auto incrment
* [not ]null if column can be null value
* unique unique
* \- this field is not map as a table column
##FAQ
1.How the xorm tag use both with json?
use space
type User struct {
User string `json:"user" orm:"user_id"`
}
## LICENSE

227
engine.go
View File

@ -6,22 +6,6 @@ import (
"reflect"
"strconv"
"strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
}
var (
Int = SQLType{"int", 11}
Char = SQLType{"char", 1}
Varchar = SQLType{"varchar", 50}
Date = SQLType{"date", 24}
Decimal = SQLType{"decimal", 26}
Float = SQLType{"float", 31}
Double = SQLType{"double", 31}
)
const (
@ -32,51 +16,6 @@ const (
MYMYSQL = "mymysql"
)
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
AutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}
type Engine struct {
Mapper IMapper
Protocol string
@ -91,21 +30,27 @@ type Engine struct {
AutoIncrement string
ShowSQL bool
QuoteIdentifier string
Statement Statement
}
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
}
func (e *Engine) OpenDB() (db *sql.DB, err error) {
db = nil
err = nil
if e.Protocol == "sqlite" {
if e.Protocol == SQLITE {
// 'sqlite:///foo.db'
db, err = sql.Open("sqlite3", e.Others)
// 'sqlite:///:memory:'
} else if e.Protocol == "mysql" {
} else if e.Protocol == MYSQL {
// 'mysql://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
connstr := strings.Join([]string{e.UserName, ":",
e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "")
db, err = sql.Open(e.Protocol, connstr)
} else if e.Protocol == "mymysql" {
} else if e.Protocol == MYMYSQL {
// DBNAME/USER/PASSWD
connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/")
db, err = sql.Open(e.Protocol, connstr)
@ -123,51 +68,95 @@ func (engine *Engine) MakeSession() (session Session, err error) {
if err != nil {
return Session{}, err
}
if engine.Protocol == "pgsql" {
if engine.Protocol == PQSQL {
engine.QuoteIdentifier = "\""
session = Session{Engine: engine, Db: db, ParamIteration: 1}
} else if engine.Protocol == "mssql" {
session = Session{Engine: engine, Db: db}
} else if engine.Protocol == MSSQL {
engine.QuoteIdentifier = ""
session = Session{Engine: engine, Db: db, ParamIteration: 1}
session = Session{Engine: engine, Db: db}
} else {
engine.QuoteIdentifier = "`"
session = Session{Engine: engine, Db: db, ParamIteration: 1}
session = Session{Engine: engine, Db: db}
}
session.Mapper = engine.Mapper
session.Init()
return
}
func (sqlType SQLType) genSQL(length int) string {
if sqlType == Date {
return " datetime "
func (engine *Engine) Where(querystring string, args ...interface{}) *Engine {
engine.Statement.Where(querystring, args...)
return engine
}
func (engine *Engine) Limit(limit int, start ...int) *Engine {
engine.Statement.Limit(limit, start...)
return engine
}
func (engine *Engine) OrderBy(order string) *Engine {
engine.Statement.OrderBy(order)
return engine
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(join_operator, tablename, condition string) *Engine {
engine.Statement.Join(join_operator, tablename, condition)
return engine
}
func (engine *Engine) GroupBy(keys string) *Engine {
engine.Statement.GroupBy(keys)
return engine
}
func (engine *Engine) Having(conditions string) *Engine {
engine.Statement.Having(conditions)
return engine
}
func (e *Engine) genColumnStr(col *Column) string {
sql := "`" + col.Name + "` "
if col.SQLType == Date {
sql += " datetime "
} else {
if e.Protocol == SQLITE && col.IsPrimaryKey {
sql += "integer"
} else {
sql += col.SQLType.Name
}
if e.Protocol != SQLITE {
if col.SQLType != Decimal {
sql += "(" + strconv.Itoa(col.Length) + ")"
} else {
sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")"
}
}
}
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
if col.Nullable {
sql += " NULL "
} else {
sql += " NOT NULL "
}
//fmt.Println(key)
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.IsAutoIncrement {
sql += e.AutoIncrement + " "
}
if col.IsUnique {
sql += "Unique "
}
return sql
}
func (e *Engine) genCreateSQL(table *Table) string {
sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` ("
//fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey))
for _, col := range table.Columns {
if col.Name != "" {
sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " "
if col.Nullable {
sql += " NULL "
} else {
sql += " NOT NULL "
}
//fmt.Println(key)
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.AutoIncrement {
sql += e.AutoIncrement + " "
}
if col.IsUnique {
sql += "Unique "
}
sql += ","
}
sql += e.genColumnStr(&col)
sql += ","
}
sql = sql[:len(sql)-2] + ");"
if e.ShowSQL {
@ -184,27 +173,6 @@ func (e *Engine) genDropSQL(table *Table) string {
return sql
}
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
/*
map an object into a table object
*/
@ -242,7 +210,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
case "null":
col.Nullable = (tags[j-1] != "not")
case "autoincr":
col.AutoIncrement = true
col.IsAutoIncrement = true
case "default":
col.Default = tags[j+1]
case "int":
@ -255,6 +223,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType)
col.Length = col.SQLType.DefaultLength
col.Length2 = col.SQLType.DefaultLength2
}
if col.Name == "" {
@ -266,7 +235,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if col.Name == "" {
sqlType := Type2SQLType(fieldType)
col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, true, "", false, false, false}
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", false, false, false}
}
table.Columns[col.Name] = col
if strings.ToLower(t.Field(i).Name) == "id" {
@ -278,7 +247,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if pkstr != "" {
col := table.Columns[pkstr]
col.IsPrimaryKey = true
col.AutoIncrement = true
col.IsAutoIncrement = true
col.Nullable = false
col.Length = Int.DefaultLength
table.PrimaryKey = col.Name
@ -292,7 +261,6 @@ func (engine *Engine) MapType(t reflect.Type) Table {
func (engine *Engine) Map(beans ...interface{}) (e error) {
for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; !ok {
table := engine.MapOne(bean)
@ -304,7 +272,6 @@ func (engine *Engine) Map(beans ...interface{}) (e error) {
func (engine *Engine) UnMap(beans ...interface{}) (e error) {
for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; ok {
delete(engine.Tables, tableName)
@ -380,7 +347,9 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
if err != nil {
return -1, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Insert(beans...)
}
@ -390,7 +359,9 @@ func (engine *Engine) Update(bean interface{}) (int64, error) {
if err != nil {
return -1, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Update(bean)
}
@ -400,7 +371,9 @@ func (engine *Engine) Delete(bean interface{}) (int64, error) {
if err != nil {
return -1, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Delete(bean)
}
@ -410,7 +383,9 @@ func (engine *Engine) Get(bean interface{}) error {
if err != nil {
return err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Get(bean)
}
@ -420,7 +395,9 @@ func (engine *Engine) Find(beans interface{}) error {
if err != nil {
return err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Find(beans)
}
@ -430,6 +407,8 @@ func (engine *Engine) Count(bean interface{}) (int64, error) {
if err != nil {
return 0, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Count(bean)
}

View File

@ -37,30 +37,28 @@ func Type2StructName(v reflect.Type) string {
type Session struct {
Db *sql.DB
Engine *Engine
Tx *sql.Tx
Statements []Statement
Mapper IMapper
AutoCommit bool
ParamIteration int
IsAutoCommit bool
IsAutoRollback bool
CurStatementIdx int
}
func (session *Session) Init() {
session.Statements = make([]Statement, 0)
/*session.Statement.TableName = ""
session.Statement.LimitStr = 0
session.Statement.OffsetStr = 0
session.Statement.WhereStr = ""
session.Statement.ParamStr = make([]interface{}, 0)
session.Statement.OrderStr = ""
session.Statement.JoinStr = ""
session.Statement.GroupByStr = ""
session.Statement.HavingStr = ""*/
session.CurStatementIdx = -1
session.ParamIteration = 1
session.IsAutoCommit = true
session.IsAutoRollback = false
}
func (session *Session) Close() {
rollbackfunc := func() {
if session.IsAutoRollback {
session.Rollback()
}
}
defer rollbackfunc()
defer session.Db.Close()
}
@ -78,102 +76,84 @@ func (session *Session) AutoStatement() *Statement {
return session.CurrentStatement()
}
//Execute sql
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Where(querystring interface{}, args ...interface{}) *Session {
func (session *Session) Where(querystring string, args ...interface{}) *Session {
statement := session.AutoStatement()
switch querystring := querystring.(type) {
case string:
statement.WhereStr = querystring
case int:
if session.Engine.Protocol == "pgsql" {
statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration)
} else {
statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier)
session.ParamIteration++
}
args = append(args, querystring)
}
statement.ParamStr = args
statement.Where(querystring, args...)
return session
}
func (session *Session) Limit(start int, size ...int) *Session {
session.AutoStatement().LimitStr = start
if len(size) > 0 {
session.CurrentStatement().OffsetStr = size[0]
}
return session
}
func (session *Session) Offset(offset int) *Session {
session.AutoStatement().OffsetStr = offset
func (session *Session) Limit(limit int, start ...int) *Session {
statement := session.AutoStatement()
statement.Limit(limit, start...)
return session
}
func (session *Session) OrderBy(order string) *Session {
session.AutoStatement().OrderStr = order
statement := session.AutoStatement()
statement.OrderBy(order)
return session
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(join_operator, tablename, condition string) *Session {
if session.AutoStatement().JoinStr != "" {
session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
} else {
session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
statement := session.AutoStatement()
statement.Join(join_operator, tablename, condition)
return session
}
func (session *Session) GroupBy(keys string) *Session {
session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
statement := session.AutoStatement()
statement.GroupBy(keys)
return session
}
func (session *Session) Having(conditions string) *Session {
session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions)
statement := session.AutoStatement()
statement.Having(conditions)
return session
}
func (session *Session) Begin() {
func (session *Session) Begin() error {
session.IsAutoCommit = false
session.IsAutoRollback = true
tx, err := session.Db.Begin()
session.Tx = tx
return err
}
func (session *Session) Rollback() {
func (session *Session) Rollback() error {
return session.Tx.Rollback()
}
func (session *Session) Commit() {
for _, statement := range session.Statements {
sql := statement.generateSql()
session.Exec(sql)
}
func (session *Session) Commit() error {
return session.Tx.Commit()
}
func (session *Session) TableName(bean interface{}) string {
return session.Mapper.Obj2Table(StructName(bean))
}
func (session *Session) SetStatement(statement *Statement) {
if session.CurStatementIdx == len(session.Statements)-1 {
session.Statements = append(session.Statements, *statement)
} else {
session.Statements[session.CurStatementIdx+1] = *statement
}
session.CurStatementIdx = session.CurStatementIdx + 1
}
func (session *Session) newStatement() {
state := Statement{}
state.Session = session
session.Statements = append(session.Statements, state)
session.CurStatementIdx = len(session.Statements) - 1
if session.CurStatementIdx == len(session.Statements)-1 {
state := Statement{Session: session}
state.Init()
session.Statements = append(session.Statements, state)
}
session.CurStatementIdx = session.CurStatementIdx + 1
}
func (session *Session) clearStatment() {
session.Statements[session.CurStatementIdx].Init()
session.CurStatementIdx = session.CurStatementIdx - 1
}
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
@ -250,13 +230,32 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
return nil
}
func (session *Session) Get(output interface{}) error {
//Execute sql
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Get(bean interface{}) error {
statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(output)
statement.Limit(1)
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
statement.Table = &table
colNames, args := session.BuildConditions(&table, bean)
statement.ColumnStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
resultsSlice, err := session.FindMap(statement)
if err != nil {
return err
@ -265,7 +264,7 @@ func (session *Session) Get(output interface{}) error {
return nil
} else if len(resultsSlice) == 1 {
results := resultsSlice[0]
err := session.scanMapIntoStruct(output, results)
err := session.scanMapIntoStruct(bean, results)
if err != nil {
return err
}
@ -277,12 +276,15 @@ func (session *Session) Get(output interface{}) error {
func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
statement.Table = &table
resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr)
colNames, args := session.BuildConditions(&table, bean)
statement.ColumnStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
resultsSlice, err := session.SQL2Map(statement.genCountSql(), append(statement.Params, statement.BeanArgs...))
if err != nil {
return 0, err
}
@ -346,6 +348,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
}
for res.Next() {
result := make(map[string][]byte)
//scanResultContainers := make([]interface{}, len(fields))
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
@ -367,7 +370,6 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
@ -397,7 +399,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) {
sqls := statement.generateSql()
return session.SQL2Map(sqls, statement.ParamStr)
return session.SQL2Map(sqls, append(statement.Params, statement.BeanArgs...))
}
func (session *Session) Insert(beans ...interface{}) (int64, error) {
@ -421,7 +423,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
val := fieldValue.Interface()
if col.AutoIncrement {
if col.IsAutoIncrement {
if fieldValue.Int() == 0 {
continue
}
@ -442,7 +444,14 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
var res sql.Result
var err error
if session.IsAutoCommit {
res, err = session.Exec(statement, args...)
} else {
res, err = session.Tx.Exec(statement, args...)
}
if err != nil {
return -1, err
}
@ -455,16 +464,10 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return id, nil
}
func (session *Session) Update(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
func (session *Session) BuildConditions(table *Table, bean interface{}) ([]string, []interface{}) {
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
@ -491,9 +494,18 @@ func (session *Session) Update(bean interface{}) (int64, error) {
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
return colNames, args
}
func (session *Session) Update(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames, args := session.BuildConditions(&table, bean)
var condition = ""
st := session.CurrentStatement()
if st != nil && st.WhereStr != "" {
st := session.AutoStatement()
defer session.clearStatment()
if st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr)
}
@ -516,7 +528,15 @@ func (session *Session) Update(bean interface{}) (int64, error) {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
var res sql.Result
var err error
if session.IsAutoCommit {
fmt.Println("session.Exec")
res, err = session.Exec(statement, append(args, st.Params...)...)
} else {
fmt.Println("tx.Exec")
res, err = session.Tx.Exec(statement, append(args, st.Params...)...)
}
if err != nil {
return -1, err
}
@ -532,42 +552,12 @@ func (session *Session) Update(bean interface{}) (int64, error) {
func (session *Session) Delete(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
switch fieldType.Kind() {
case reflect.String:
if fieldValue.String() == "" {
continue
}
case reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time)
if t.IsZero() {
continue
}
}
default:
continue
}
args = append(args, val)
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
colNames, args := session.BuildConditions(&table, bean)
var condition = ""
st := session.CurrentStatement()
if st != nil && st.WhereStr != "" {
st := session.AutoStatement()
defer session.clearStatment()
if st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr)
if len(colNames) > 0 {
condition += " and "
@ -587,7 +577,13 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
var res sql.Result
var err error
if session.IsAutoCommit {
res, err = session.Exec(statement, append(st.Params, args...)...)
} else {
res, err = session.Tx.Exec(statement, append(st.Params, args...)...)
}
if err != nil {
return -1, err
}

View File

@ -5,40 +5,66 @@ import (
)
type Statement struct {
TableName string
Table *Table
Session *Session
LimitStr int
OffsetStr int
Start int
LimitN int
WhereStr string
ParamStr []interface{}
Params []interface{}
OrderStr string
JoinStr string
GroupByStr string
HavingStr string
ColumnStr string
BeanArgs []interface{}
}
func (statement *Statement) Limit(start int, size ...int) *Statement {
statement.LimitStr = start
if len(size) > 0 {
statement.OffsetStr = size[0]
func (statement *Statement) Init() {
statement.Table = nil
statement.Session = nil
statement.Start = 0
statement.LimitN = 0
statement.WhereStr = ""
statement.Params = make([]interface{}, 0)
statement.OrderStr = ""
statement.JoinStr = ""
statement.GroupByStr = ""
statement.HavingStr = ""
statement.ColumnStr = ""
statement.BeanArgs = make([]interface{}, 0)
}
func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.WhereStr = querystring
statement.Params = args
}
func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit
if len(start) > 0 {
statement.Start = start[0]
}
return statement
}
func (statement *Statement) Offset(offset int) *Statement {
statement.OffsetStr = offset
return statement
}
func (statement *Statement) OrderBy(order string) *Statement {
func (statement *Statement) OrderBy(order string) {
statement.OrderStr = order
return statement
}
func (statement *Statement) Select(colums string) *Statement {
//statement.ColumnStr = colums
return statement
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(join_operator, tablename, condition string) {
if statement.JoinStr != "" {
statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
} else {
statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
}
func (statement *Statement) GroupBy(keys string) {
statement.GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
}
func (statement *Statement) Having(conditions string) {
statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
}
func (statement Statement) generateSql() string {
@ -50,27 +76,41 @@ func (statement Statement) genCountSql() string {
return statement.genSelectSql("count(*) as total")
}
func (statement Statement) genExecSql() string {
return ""
}
func (statement Statement) genSelectSql(columnStr string) (a string) {
session := statement.Session
if session.Engine.Protocol == "mssql" {
if statement.OffsetStr > 0 {
if statement.Start > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
statement.Table.PKColumn().Name,
columnStr,
statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
a = fmt.Sprintf("select %v from (%v) "+
"as a where rownum between %v and %v",
columnStr,
a,
statement.OffsetStr,
statement.LimitStr)
} else if statement.LimitStr > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name)
statement.Start,
statement.LimitN)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -85,6 +125,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -103,6 +148,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
}
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -113,10 +163,10 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
if statement.OffsetStr > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr)
} else if statement.LimitStr > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr)
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.Start, statement.LimitN)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
}
return

90
table.go Normal file
View File

@ -0,0 +1,90 @@
package xorm
import (
"reflect"
"strconv"
"strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
var (
Int = SQLType{"int", 11, 0}
Char = SQLType{"char", 1, 0}
Bool = SQLType{"int", 1, 0}
Varchar = SQLType{"varchar", 50, 0}
Date = SQLType{"date", 24, 0}
Decimal = SQLType{"decimal", 26, 2}
Float = SQLType{"float", 31, 0}
Double = SQLType{"double", 31, 0}
)
func (sqlType SQLType) genSQL(length int) string {
if sqlType == Date {
return " datetime "
}
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.Bool:
st = Bool
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Length2 int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
IsAutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
/*func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}*/
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}

View File

@ -3,8 +3,7 @@ package xorm_test
import (
"fmt"
_ "github.com/Go-SQL-Driver/MySQL"
//_ "github.com/ziutek/mymysql/godrv"
//_ "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
"testing"
"time"
"xorm"
@ -36,21 +35,14 @@ type Userinfo struct {
var engine xorm.Engine
func TestCreateEngine(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
//engine = orm.Create("mymysql://root:123@localhost/test")
//engine = orm.Create("sqlite:///test.db")
engine.ShowSQL = true
}
func TestDirectCreateTable(t *testing.T) {
func directCreateTable(t *testing.T) {
err := engine.CreateTables(&Userinfo{})
if err != nil {
t.Error(err)
}
}
func TestMapper(t *testing.T) {
func mapper(t *testing.T) {
err := engine.UnMap(&Userinfo{})
if err != nil {
t.Error(err)
@ -72,7 +64,7 @@ func TestMapper(t *testing.T) {
}
}
func TestInsert(t *testing.T) {
func insert(t *testing.T) {
user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()}
_, err := engine.Insert(&user)
if err != nil {
@ -80,7 +72,7 @@ func TestInsert(t *testing.T) {
}
}
func TestInsertAutoIncr(t *testing.T) {
func insertAutoIncr(t *testing.T) {
// auto increment insert
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err := engine.Insert(&user)
@ -89,7 +81,7 @@ func TestInsertAutoIncr(t *testing.T) {
}
}
func TestInsertMulti(t *testing.T) {
func insertMulti(t *testing.T) {
user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}
user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}
_, err := engine.Insert(&user1, &user2)
@ -98,7 +90,7 @@ func TestInsertMulti(t *testing.T) {
}
}
func TestUpdate(t *testing.T) {
func update(t *testing.T) {
// update by id
user := Userinfo{Uid: 1, Username: "xxx"}
_, err := engine.Update(&user)
@ -107,7 +99,7 @@ func TestUpdate(t *testing.T) {
}
}
func TestDelete(t *testing.T) {
func delete(t *testing.T) {
user := Userinfo{Uid: 1}
_, err := engine.Delete(&user)
if err != nil {
@ -115,7 +107,7 @@ func TestDelete(t *testing.T) {
}
}
func TestGet(t *testing.T) {
func get(t *testing.T) {
user := Userinfo{Uid: 2}
err := engine.Get(&user)
@ -125,7 +117,7 @@ func TestGet(t *testing.T) {
fmt.Println(user)
}
func TestFind(t *testing.T) {
func find(t *testing.T) {
users := make([]Userinfo, 0)
err := engine.Find(&users)
@ -135,8 +127,8 @@ func TestFind(t *testing.T) {
fmt.Println(users)
}
func TestCount(t *testing.T) {
user := Userinfo{}
func count(t *testing.T) {
user := Userinfo{Departname: "dev"}
total, err := engine.Count(&user)
if err != nil {
t.Error(err)
@ -144,47 +136,117 @@ func TestCount(t *testing.T) {
fmt.Printf("Total %d records!!!", total)
}
func TestWhere(t *testing.T) {
func where(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Where("id > ?", 2).Find(&users)
err := engine.Where("id > ?", 2).Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestLimit(t *testing.T) {
func limit(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Limit(2, 1).Find(&users)
err := engine.Limit(2, 1).Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestOrder(t *testing.T) {
func order(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.OrderBy("id desc").Find(&users)
err := engine.OrderBy("id desc").Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestTransaction(*testing.T) {
func transaction(t *testing.T) {
counter := func() {
total, err := engine.Count(&Userinfo{})
if err != nil {
t.Error(err)
}
fmt.Printf("----now total %v records\n", total)
}
counter()
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
return
}
defer counter()
session.Begin()
session.IsAutoRollback = true
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
t.Error(err)
return
}
user2 := Userinfo{Username: "yyy"}
_, err = session.Where("id = ?", 2).Update(&user2)
if err != nil {
t.Error(err)
return
}
_, err = session.Delete(&user2)
if err != nil {
t.Error(err)
return
}
err = session.Commit()
if err != nil {
t.Error(err)
return
}
}
func TestMysql(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
engine.ShowSQL = true
directCreateTable(t)
mapper(t)
insert(t)
insertAutoIncr(t)
insertMulti(t)
update(t)
delete(t)
get(t)
find(t)
count(t)
where(t)
limit(t)
order(t)
transaction(t)
}
func TestSqlite(t *testing.T) {
engine = xorm.Create("sqlite:///test.db")
engine.ShowSQL = true
directCreateTable(t)
mapper(t)
insert(t)
insertAutoIncr(t)
insertMulti(t)
update(t)
delete(t)
get(t)
find(t)
count(t)
where(t)
limit(t)
order(t)
transaction(t)
}