fix mysql test and sqlite test

This commit is contained in:
Lunny Xiao 2013-05-06 16:01:17 +08:00
parent 95927253d4
commit 74e0e3b175
7 changed files with 580 additions and 392 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.o
*.a
*.so
*.db
# Folders
_obj

View File

@ -1,5 +1,5 @@
xorm
=====
# xorm
===========
[中文](README_CN.md)
@ -17,71 +17,72 @@ Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
### Installing xorm
## Installing xorm
go get github.com/lunny/xorm
### Quick Start
## Quick Start
1.Create an database engine (for example: mysql)
1.Create a database engine (for example: mysql)
```go
engine := xorm.Create("mysql://root:123@localhost/test")
```
2.Define your struct
```go
type User struct {
Id int
Name string
Age int `xorm:"-"`
}
```
for Simple Task, just use engine's functions:
begin start, you should create a database and then we create the tables
before beginning, you should create a database in mysql and then we will create the tables.
```go
err := engine.CreateTables(&User{})
```
then, insert an struct to table
```go
id, err := engine.Insert(&User{Name:"lunny"})
```
or you want to update this struct
```go
user := User{Id:1, Name:"xlw"}
rows, err := engine.Update(&user)
```
3.Fetch a single object by user
```go
var user = User{Id:27}
engine.Get(&user)
var user = User{Name:"xlw"}
engine.Get(&user)
```
##Deep Use
for deep use, you should create a session, this func will create a connection to db
```go
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return
}
```
1.Fetch a single object by where
```go
var user Userinfo
session.Where("id=?", 27).Get(&user)
@ -93,11 +94,11 @@ session.Where("name = ?", "john").Get(&user3) // more complex query
var user4 Userinfo
session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex
```
2.Fetch multiple objects
```go
var allusers []Userinfo
err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20
@ -106,9 +107,9 @@ err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 1
var everyone []Userinfo
err := session.Find(&everyone)
```
###***About Map Rules***
##***Mapping Rules***
1.Struct and struct's fields name should be Pascal style, and the table and column's name default is us
for example:
The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname.
@ -117,13 +118,22 @@ If the keyname is 'UserName' will turn into the select colum 'user_name'
2.You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper
another is use field tag, field tag support the below keywords:
[name] column name
pk the field is a primary key
int(11)/varchar(50) column type
autoincr auto incrment
[not ]null if column can be null value
unique unique
- this field is not map as a table column
* [name] column name
* pk the field is a primary key
* int(11)/varchar(50) column type
* autoincr auto incrment
* [not ]null if column can be null value
* unique unique
* \- this field is not map as a table column
##FAQ
1.How the xorm tag use both with json?
use space
type User struct {
User string `json:"user" orm:"user_id"`
}
## LICENSE

207
engine.go
View File

@ -6,22 +6,6 @@ import (
"reflect"
"strconv"
"strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
}
var (
Int = SQLType{"int", 11}
Char = SQLType{"char", 1}
Varchar = SQLType{"varchar", 50}
Date = SQLType{"date", 24}
Decimal = SQLType{"decimal", 26}
Float = SQLType{"float", 31}
Double = SQLType{"double", 31}
)
const (
@ -32,51 +16,6 @@ const (
MYMYSQL = "mymysql"
)
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
AutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}
type Engine struct {
Mapper IMapper
Protocol string
@ -91,21 +30,27 @@ type Engine struct {
AutoIncrement string
ShowSQL bool
QuoteIdentifier string
Statement Statement
}
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
}
func (e *Engine) OpenDB() (db *sql.DB, err error) {
db = nil
err = nil
if e.Protocol == "sqlite" {
if e.Protocol == SQLITE {
// 'sqlite:///foo.db'
db, err = sql.Open("sqlite3", e.Others)
// 'sqlite:///:memory:'
} else if e.Protocol == "mysql" {
} else if e.Protocol == MYSQL {
// 'mysql://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
connstr := strings.Join([]string{e.UserName, ":",
e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "")
db, err = sql.Open(e.Protocol, connstr)
} else if e.Protocol == "mymysql" {
} else if e.Protocol == MYMYSQL {
// DBNAME/USER/PASSWD
connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/")
db, err = sql.Open(e.Protocol, connstr)
@ -123,34 +68,71 @@ func (engine *Engine) MakeSession() (session Session, err error) {
if err != nil {
return Session{}, err
}
if engine.Protocol == "pgsql" {
if engine.Protocol == PQSQL {
engine.QuoteIdentifier = "\""
session = Session{Engine: engine, Db: db, ParamIteration: 1}
} else if engine.Protocol == "mssql" {
session = Session{Engine: engine, Db: db}
} else if engine.Protocol == MSSQL {
engine.QuoteIdentifier = ""
session = Session{Engine: engine, Db: db, ParamIteration: 1}
session = Session{Engine: engine, Db: db}
} else {
engine.QuoteIdentifier = "`"
session = Session{Engine: engine, Db: db, ParamIteration: 1}
session = Session{Engine: engine, Db: db}
}
session.Mapper = engine.Mapper
session.Init()
return
}
func (sqlType SQLType) genSQL(length int) string {
if sqlType == Date {
return " datetime "
}
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
func (engine *Engine) Where(querystring string, args ...interface{}) *Engine {
engine.Statement.Where(querystring, args...)
return engine
}
func (engine *Engine) Limit(limit int, start ...int) *Engine {
engine.Statement.Limit(limit, start...)
return engine
}
func (engine *Engine) OrderBy(order string) *Engine {
engine.Statement.OrderBy(order)
return engine
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (engine *Engine) Join(join_operator, tablename, condition string) *Engine {
engine.Statement.Join(join_operator, tablename, condition)
return engine
}
func (engine *Engine) GroupBy(keys string) *Engine {
engine.Statement.GroupBy(keys)
return engine
}
func (engine *Engine) Having(conditions string) *Engine {
engine.Statement.Having(conditions)
return engine
}
func (e *Engine) genColumnStr(col *Column) string {
sql := "`" + col.Name + "` "
if col.SQLType == Date {
sql += " datetime "
} else {
if e.Protocol == SQLITE && col.IsPrimaryKey {
sql += "integer"
} else {
sql += col.SQLType.Name
}
if e.Protocol != SQLITE {
if col.SQLType != Decimal {
sql += "(" + strconv.Itoa(col.Length) + ")"
} else {
sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")"
}
}
}
func (e *Engine) genCreateSQL(table *Table) string {
sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` ("
//fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey))
for _, col := range table.Columns {
if col.Name != "" {
sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " "
if col.Nullable {
sql += " NULL "
} else {
@ -160,14 +142,21 @@ func (e *Engine) genCreateSQL(table *Table) string {
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.AutoIncrement {
if col.IsAutoIncrement {
sql += e.AutoIncrement + " "
}
if col.IsUnique {
sql += "Unique "
}
sql += ","
return sql
}
func (e *Engine) genCreateSQL(table *Table) string {
sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` ("
//fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey))
for _, col := range table.Columns {
sql += e.genColumnStr(&col)
sql += ","
}
sql = sql[:len(sql)-2] + ");"
if e.ShowSQL {
@ -184,27 +173,6 @@ func (e *Engine) genDropSQL(table *Table) string {
return sql
}
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
/*
map an object into a table object
*/
@ -242,7 +210,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
case "null":
col.Nullable = (tags[j-1] != "not")
case "autoincr":
col.AutoIncrement = true
col.IsAutoIncrement = true
case "default":
col.Default = tags[j+1]
case "int":
@ -255,6 +223,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType)
col.Length = col.SQLType.DefaultLength
col.Length2 = col.SQLType.DefaultLength2
}
if col.Name == "" {
@ -266,7 +235,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if col.Name == "" {
sqlType := Type2SQLType(fieldType)
col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, true, "", false, false, false}
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", false, false, false}
}
table.Columns[col.Name] = col
if strings.ToLower(t.Field(i).Name) == "id" {
@ -278,7 +247,7 @@ func (engine *Engine) MapType(t reflect.Type) Table {
if pkstr != "" {
col := table.Columns[pkstr]
col.IsPrimaryKey = true
col.AutoIncrement = true
col.IsAutoIncrement = true
col.Nullable = false
col.Length = Int.DefaultLength
table.PrimaryKey = col.Name
@ -292,7 +261,6 @@ func (engine *Engine) MapType(t reflect.Type) Table {
func (engine *Engine) Map(beans ...interface{}) (e error) {
for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; !ok {
table := engine.MapOne(bean)
@ -304,7 +272,6 @@ func (engine *Engine) Map(beans ...interface{}) (e error) {
func (engine *Engine) UnMap(beans ...interface{}) (e error) {
for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; ok {
delete(engine.Tables, tableName)
@ -380,7 +347,9 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
if err != nil {
return -1, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Insert(beans...)
}
@ -390,7 +359,9 @@ func (engine *Engine) Update(bean interface{}) (int64, error) {
if err != nil {
return -1, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Update(bean)
}
@ -400,7 +371,9 @@ func (engine *Engine) Delete(bean interface{}) (int64, error) {
if err != nil {
return -1, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Delete(bean)
}
@ -410,7 +383,9 @@ func (engine *Engine) Get(bean interface{}) error {
if err != nil {
return err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Get(bean)
}
@ -420,7 +395,9 @@ func (engine *Engine) Find(beans interface{}) error {
if err != nil {
return err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Find(beans)
}
@ -430,6 +407,8 @@ func (engine *Engine) Count(bean interface{}) (int64, error) {
if err != nil {
return 0, err
}
defer engine.Statement.Init()
engine.Statement.Session = &session
session.SetStatement(&engine.Statement)
return session.Count(bean)
}

View File

@ -37,30 +37,28 @@ func Type2StructName(v reflect.Type) string {
type Session struct {
Db *sql.DB
Engine *Engine
Tx *sql.Tx
Statements []Statement
Mapper IMapper
AutoCommit bool
ParamIteration int
IsAutoCommit bool
IsAutoRollback bool
CurStatementIdx int
}
func (session *Session) Init() {
session.Statements = make([]Statement, 0)
/*session.Statement.TableName = ""
session.Statement.LimitStr = 0
session.Statement.OffsetStr = 0
session.Statement.WhereStr = ""
session.Statement.ParamStr = make([]interface{}, 0)
session.Statement.OrderStr = ""
session.Statement.JoinStr = ""
session.Statement.GroupByStr = ""
session.Statement.HavingStr = ""*/
session.CurStatementIdx = -1
session.ParamIteration = 1
session.IsAutoCommit = true
session.IsAutoRollback = false
}
func (session *Session) Close() {
rollbackfunc := func() {
if session.IsAutoRollback {
session.Rollback()
}
}
defer rollbackfunc()
defer session.Db.Close()
}
@ -78,102 +76,84 @@ func (session *Session) AutoStatement() *Statement {
return session.CurrentStatement()
}
//Execute sql
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Where(querystring interface{}, args ...interface{}) *Session {
func (session *Session) Where(querystring string, args ...interface{}) *Session {
statement := session.AutoStatement()
switch querystring := querystring.(type) {
case string:
statement.WhereStr = querystring
case int:
if session.Engine.Protocol == "pgsql" {
statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration)
} else {
statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier)
session.ParamIteration++
}
args = append(args, querystring)
}
statement.ParamStr = args
statement.Where(querystring, args...)
return session
}
func (session *Session) Limit(start int, size ...int) *Session {
session.AutoStatement().LimitStr = start
if len(size) > 0 {
session.CurrentStatement().OffsetStr = size[0]
}
return session
}
func (session *Session) Offset(offset int) *Session {
session.AutoStatement().OffsetStr = offset
func (session *Session) Limit(limit int, start ...int) *Session {
statement := session.AutoStatement()
statement.Limit(limit, start...)
return session
}
func (session *Session) OrderBy(order string) *Session {
session.AutoStatement().OrderStr = order
statement := session.AutoStatement()
statement.OrderBy(order)
return session
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(join_operator, tablename, condition string) *Session {
if session.AutoStatement().JoinStr != "" {
session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
} else {
session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
statement := session.AutoStatement()
statement.Join(join_operator, tablename, condition)
return session
}
func (session *Session) GroupBy(keys string) *Session {
session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
statement := session.AutoStatement()
statement.GroupBy(keys)
return session
}
func (session *Session) Having(conditions string) *Session {
session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions)
statement := session.AutoStatement()
statement.Having(conditions)
return session
}
func (session *Session) Begin() {
func (session *Session) Begin() error {
session.IsAutoCommit = false
session.IsAutoRollback = true
tx, err := session.Db.Begin()
session.Tx = tx
return err
}
func (session *Session) Rollback() {
func (session *Session) Rollback() error {
return session.Tx.Rollback()
}
func (session *Session) Commit() {
for _, statement := range session.Statements {
sql := statement.generateSql()
session.Exec(sql)
}
func (session *Session) Commit() error {
return session.Tx.Commit()
}
func (session *Session) TableName(bean interface{}) string {
return session.Mapper.Obj2Table(StructName(bean))
}
func (session *Session) SetStatement(statement *Statement) {
if session.CurStatementIdx == len(session.Statements)-1 {
session.Statements = append(session.Statements, *statement)
} else {
session.Statements[session.CurStatementIdx+1] = *statement
}
session.CurStatementIdx = session.CurStatementIdx + 1
}
func (session *Session) newStatement() {
state := Statement{}
state.Session = session
if session.CurStatementIdx == len(session.Statements)-1 {
state := Statement{Session: session}
state.Init()
session.Statements = append(session.Statements, state)
session.CurStatementIdx = len(session.Statements) - 1
}
session.CurStatementIdx = session.CurStatementIdx + 1
}
func (session *Session) clearStatment() {
session.Statements[session.CurStatementIdx].Init()
session.CurStatementIdx = session.CurStatementIdx - 1
}
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
@ -250,13 +230,32 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b
return nil
}
func (session *Session) Get(output interface{}) error {
//Execute sql
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Get(bean interface{}) error {
statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(output)
statement.Limit(1)
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
statement.Table = &table
colNames, args := session.BuildConditions(&table, bean)
statement.ColumnStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
resultsSlice, err := session.FindMap(statement)
if err != nil {
return err
@ -265,7 +264,7 @@ func (session *Session) Get(output interface{}) error {
return nil
} else if len(resultsSlice) == 1 {
results := resultsSlice[0]
err := session.scanMapIntoStruct(output, results)
err := session.scanMapIntoStruct(bean, results)
if err != nil {
return err
}
@ -277,12 +276,15 @@ func (session *Session) Get(output interface{}) error {
func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
statement.Table = &table
resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr)
colNames, args := session.BuildConditions(&table, bean)
statement.ColumnStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
resultsSlice, err := session.SQL2Map(statement.genCountSql(), append(statement.Params, statement.BeanArgs...))
if err != nil {
return 0, err
}
@ -346,6 +348,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
}
for res.Next() {
result := make(map[string][]byte)
//scanResultContainers := make([]interface{}, len(fields))
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
@ -367,7 +370,6 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
@ -397,7 +399,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli
func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) {
sqls := statement.generateSql()
return session.SQL2Map(sqls, statement.ParamStr)
return session.SQL2Map(sqls, append(statement.Params, statement.BeanArgs...))
}
func (session *Session) Insert(beans ...interface{}) (int64, error) {
@ -421,7 +423,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
val := fieldValue.Interface()
if col.AutoIncrement {
if col.IsAutoIncrement {
if fieldValue.Int() == 0 {
continue
}
@ -442,7 +444,14 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
var res sql.Result
var err error
if session.IsAutoCommit {
res, err = session.Exec(statement, args...)
} else {
res, err = session.Tx.Exec(statement, args...)
}
if err != nil {
return -1, err
}
@ -455,16 +464,10 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) {
return id, nil
}
func (session *Session) Update(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
func (session *Session) BuildConditions(table *Table, bean interface{}) ([]string, []interface{}) {
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
@ -491,9 +494,18 @@ func (session *Session) Update(bean interface{}) (int64, error) {
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
return colNames, args
}
func (session *Session) Update(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames, args := session.BuildConditions(&table, bean)
var condition = ""
st := session.CurrentStatement()
if st != nil && st.WhereStr != "" {
st := session.AutoStatement()
defer session.clearStatment()
if st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr)
}
@ -516,7 +528,15 @@ func (session *Session) Update(bean interface{}) (int64, error) {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
var res sql.Result
var err error
if session.IsAutoCommit {
fmt.Println("session.Exec")
res, err = session.Exec(statement, append(args, st.Params...)...)
} else {
fmt.Println("tx.Exec")
res, err = session.Tx.Exec(statement, append(args, st.Params...)...)
}
if err != nil {
return -1, err
}
@ -532,42 +552,12 @@ func (session *Session) Update(bean interface{}) (int64, error) {
func (session *Session) Delete(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
switch fieldType.Kind() {
case reflect.String:
if fieldValue.String() == "" {
continue
}
case reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time)
if t.IsZero() {
continue
}
}
default:
continue
}
args = append(args, val)
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
colNames, args := session.BuildConditions(&table, bean)
var condition = ""
st := session.CurrentStatement()
if st != nil && st.WhereStr != "" {
st := session.AutoStatement()
defer session.clearStatment()
if st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr)
if len(colNames) > 0 {
condition += " and "
@ -587,7 +577,13 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
var res sql.Result
var err error
if session.IsAutoCommit {
res, err = session.Exec(statement, append(st.Params, args...)...)
} else {
res, err = session.Tx.Exec(statement, append(st.Params, args...)...)
}
if err != nil {
return -1, err
}

View File

@ -5,40 +5,66 @@ import (
)
type Statement struct {
TableName string
Table *Table
Session *Session
LimitStr int
OffsetStr int
Start int
LimitN int
WhereStr string
ParamStr []interface{}
Params []interface{}
OrderStr string
JoinStr string
GroupByStr string
HavingStr string
ColumnStr string
BeanArgs []interface{}
}
func (statement *Statement) Limit(start int, size ...int) *Statement {
statement.LimitStr = start
if len(size) > 0 {
statement.OffsetStr = size[0]
}
return statement
func (statement *Statement) Init() {
statement.Table = nil
statement.Session = nil
statement.Start = 0
statement.LimitN = 0
statement.WhereStr = ""
statement.Params = make([]interface{}, 0)
statement.OrderStr = ""
statement.JoinStr = ""
statement.GroupByStr = ""
statement.HavingStr = ""
statement.ColumnStr = ""
statement.BeanArgs = make([]interface{}, 0)
}
func (statement *Statement) Offset(offset int) *Statement {
statement.OffsetStr = offset
return statement
func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.WhereStr = querystring
statement.Params = args
}
func (statement *Statement) OrderBy(order string) *Statement {
func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit
if len(start) > 0 {
statement.Start = start[0]
}
}
func (statement *Statement) OrderBy(order string) {
statement.OrderStr = order
return statement
}
func (statement *Statement) Select(colums string) *Statement {
//statement.ColumnStr = colums
return statement
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (statement *Statement) Join(join_operator, tablename, condition string) {
if statement.JoinStr != "" {
statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
} else {
statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
}
func (statement *Statement) GroupBy(keys string) {
statement.GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
}
func (statement *Statement) Having(conditions string) {
statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
}
func (statement Statement) generateSql() string {
@ -50,27 +76,41 @@ func (statement Statement) genCountSql() string {
return statement.genSelectSql("count(*) as total")
}
func (statement Statement) genExecSql() string {
return ""
}
func (statement Statement) genSelectSql(columnStr string) (a string) {
session := statement.Session
if session.Engine.Protocol == "mssql" {
if statement.OffsetStr > 0 {
if statement.Start > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
statement.Table.PKColumn().Name,
columnStr,
statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
a = fmt.Sprintf("select %v from (%v) "+
"as a where rownum between %v and %v",
columnStr,
a,
statement.OffsetStr,
statement.LimitStr)
} else if statement.LimitStr > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name)
statement.Start,
statement.LimitN)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -85,6 +125,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -103,6 +148,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
}
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ColumnStr != "" {
a = fmt.Sprintf("%v and %v", a, statement.ColumnStr)
}
} else if statement.ColumnStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
@ -113,10 +163,10 @@ func (statement Statement) genSelectSql(columnStr string) (a string) {
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
if statement.OffsetStr > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr)
} else if statement.LimitStr > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr)
if statement.Start > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.Start, statement.LimitN)
} else if statement.LimitN > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
}
}
return

90
table.go Normal file
View File

@ -0,0 +1,90 @@
package xorm
import (
"reflect"
"strconv"
"strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
DefaultLength2 int
}
var (
Int = SQLType{"int", 11, 0}
Char = SQLType{"char", 1, 0}
Bool = SQLType{"int", 1, 0}
Varchar = SQLType{"varchar", 50, 0}
Date = SQLType{"date", 24, 0}
Decimal = SQLType{"decimal", 26, 2}
Float = SQLType{"float", 31, 0}
Double = SQLType{"double", 31, 0}
)
func (sqlType SQLType) genSQL(length int) string {
if sqlType == Date {
return " datetime "
}
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.Bool:
st = Bool
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Length2 int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
IsAutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
/*func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}*/
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}

View File

@ -3,8 +3,7 @@ package xorm_test
import (
"fmt"
_ "github.com/Go-SQL-Driver/MySQL"
//_ "github.com/ziutek/mymysql/godrv"
//_ "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
"testing"
"time"
"xorm"
@ -36,21 +35,14 @@ type Userinfo struct {
var engine xorm.Engine
func TestCreateEngine(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
//engine = orm.Create("mymysql://root:123@localhost/test")
//engine = orm.Create("sqlite:///test.db")
engine.ShowSQL = true
}
func TestDirectCreateTable(t *testing.T) {
func directCreateTable(t *testing.T) {
err := engine.CreateTables(&Userinfo{})
if err != nil {
t.Error(err)
}
}
func TestMapper(t *testing.T) {
func mapper(t *testing.T) {
err := engine.UnMap(&Userinfo{})
if err != nil {
t.Error(err)
@ -72,7 +64,7 @@ func TestMapper(t *testing.T) {
}
}
func TestInsert(t *testing.T) {
func insert(t *testing.T) {
user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()}
_, err := engine.Insert(&user)
if err != nil {
@ -80,7 +72,7 @@ func TestInsert(t *testing.T) {
}
}
func TestInsertAutoIncr(t *testing.T) {
func insertAutoIncr(t *testing.T) {
// auto increment insert
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err := engine.Insert(&user)
@ -89,7 +81,7 @@ func TestInsertAutoIncr(t *testing.T) {
}
}
func TestInsertMulti(t *testing.T) {
func insertMulti(t *testing.T) {
user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}
user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}
_, err := engine.Insert(&user1, &user2)
@ -98,7 +90,7 @@ func TestInsertMulti(t *testing.T) {
}
}
func TestUpdate(t *testing.T) {
func update(t *testing.T) {
// update by id
user := Userinfo{Uid: 1, Username: "xxx"}
_, err := engine.Update(&user)
@ -107,7 +99,7 @@ func TestUpdate(t *testing.T) {
}
}
func TestDelete(t *testing.T) {
func delete(t *testing.T) {
user := Userinfo{Uid: 1}
_, err := engine.Delete(&user)
if err != nil {
@ -115,7 +107,7 @@ func TestDelete(t *testing.T) {
}
}
func TestGet(t *testing.T) {
func get(t *testing.T) {
user := Userinfo{Uid: 2}
err := engine.Get(&user)
@ -125,7 +117,7 @@ func TestGet(t *testing.T) {
fmt.Println(user)
}
func TestFind(t *testing.T) {
func find(t *testing.T) {
users := make([]Userinfo, 0)
err := engine.Find(&users)
@ -135,8 +127,8 @@ func TestFind(t *testing.T) {
fmt.Println(users)
}
func TestCount(t *testing.T) {
user := Userinfo{}
func count(t *testing.T) {
user := Userinfo{Departname: "dev"}
total, err := engine.Count(&user)
if err != nil {
t.Error(err)
@ -144,47 +136,117 @@ func TestCount(t *testing.T) {
fmt.Printf("Total %d records!!!", total)
}
func TestWhere(t *testing.T) {
func where(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Where("id > ?", 2).Find(&users)
err := engine.Where("id > ?", 2).Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestLimit(t *testing.T) {
func limit(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Limit(2, 1).Find(&users)
err := engine.Limit(2, 1).Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestOrder(t *testing.T) {
func order(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.OrderBy("id desc").Find(&users)
err := engine.OrderBy("id desc").Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestTransaction(*testing.T) {
func transaction(t *testing.T) {
counter := func() {
total, err := engine.Count(&Userinfo{})
if err != nil {
t.Error(err)
}
fmt.Printf("----now total %v records\n", total)
}
counter()
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
return
}
defer counter()
session.Begin()
session.IsAutoRollback = true
user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err = session.Insert(&user1)
if err != nil {
t.Error(err)
return
}
user2 := Userinfo{Username: "yyy"}
_, err = session.Where("id = ?", 2).Update(&user2)
if err != nil {
t.Error(err)
return
}
_, err = session.Delete(&user2)
if err != nil {
t.Error(err)
return
}
err = session.Commit()
if err != nil {
t.Error(err)
return
}
}
func TestMysql(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
engine.ShowSQL = true
directCreateTable(t)
mapper(t)
insert(t)
insertAutoIncr(t)
insertMulti(t)
update(t)
delete(t)
get(t)
find(t)
count(t)
where(t)
limit(t)
order(t)
transaction(t)
}
func TestSqlite(t *testing.T) {
engine = xorm.Create("sqlite:///test.db")
engine.ShowSQL = true
directCreateTable(t)
mapper(t)
insert(t)
insertAutoIncr(t)
insertMulti(t)
update(t)
delete(t)
get(t)
find(t)
count(t)
where(t)
limit(t)
order(t)
transaction(t)
}