use default connection params

This commit is contained in:
Lunny Xiao 2013-05-09 09:56:58 +08:00
parent 4ebee70f92
commit 09848afbcf
7 changed files with 82 additions and 104 deletions

View File

@ -1,11 +1,11 @@
# xorm # xorm
=========== ===========
[中文](README_CN.md) [中文](./README_CN.md)
xorm is an ORM for Go. It lets you map Go structs to tables in a database. xorm is an ORM for Go. It lets you map Go structs to tables in a database.
Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. Right now, it supports Mysql and SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future.
All in all, it's not entirely ready for product use yet, but it's getting there. All in all, it's not entirely ready for product use yet, but it's getting there.
@ -21,9 +21,13 @@ SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
## Quick Start ## Quick Start
1.Create a database engine (for example: mysql) 1.Create a database engine just like sql.OpenDB (for example: mysql)
engine := xorm.Create("mysql://root:123@localhost/test") engine := xorm.Create("mysql", "root:123@/test?charset=utf8")
or
engine = xorm.Create("sqlite3", "./test.db")
2.Define your struct 2.Define your struct

View File

@ -21,9 +21,13 @@ SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
## 快速开始 ## 快速开始
1.创建数据库引擎 (例如: mysql) 1.创建数据库引擎这个函数的参数和sql.OpenDB相同但不会立即创建连接 (例如: mysql)
engine := xorm.Create("mysql://root:123@localhost/test") engine := xorm.Create("mysql", "root:123@/test?charset=utf8")
or
engine = xorm.Create("sqlite3", "./test.db")
2.定义你的Struct 2.定义你的Struct

View File

@ -2,7 +2,7 @@ package xorm
import ( import (
"database/sql" "database/sql"
"fmt" //"fmt"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
@ -11,21 +11,15 @@ import (
const ( const (
PQSQL = "pqsql" PQSQL = "pqsql"
MSSQL = "mssql" MSSQL = "mssql"
SQLITE = "sqlite" SQLITE = "sqlite3"
MYSQL = "mysql" MYSQL = "mysql"
MYMYSQL = "mymysql" MYMYSQL = "mymysql"
) )
type Engine struct { type Engine struct {
Mapper IMapper Mapper IMapper
Protocol string DriverName string
UserName string DataSourceName string
Password string
Host string
Port int
DBName string
Charset string
Others string
Tables map[reflect.Type]Table Tables map[reflect.Type]Table
AutoIncrement string AutoIncrement string
ShowSQL bool ShowSQL bool
@ -45,29 +39,8 @@ func StructName(v reflect.Type) string {
return v.Name() return v.Name()
} }
func (e *Engine) OpenDB() (db *sql.DB, err error) { func (e *Engine) OpenDB() (*sql.DB, error) {
db = nil return sql.Open(e.DriverName, e.DataSourceName)
err = nil
if e.Protocol == SQLITE {
// 'sqlite:///foo.db'
db, err = sql.Open("sqlite3", e.Others)
// 'sqlite:///:memory:'
} else if e.Protocol == MYSQL {
// 'mysql://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
connstr := strings.Join([]string{e.UserName, ":",
e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "")
db, err = sql.Open(e.Protocol, connstr)
} else if e.Protocol == MYMYSQL {
// DBNAME/USER/PASSWD
connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/")
db, err = sql.Open(e.Protocol, connstr)
// unix:SOCKPATH*DBNAME/USER/PASSWD
// unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD
// tcp:ADDR*DBNAME/USER/PASSWD
// tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD
}
return
} }
func (engine *Engine) MakeSession() (session Session, err error) { func (engine *Engine) MakeSession() (session Session, err error) {
@ -75,16 +48,8 @@ func (engine *Engine) MakeSession() (session Session, err error) {
if err != nil { if err != nil {
return Session{}, err return Session{}, err
} }
if engine.Protocol == PQSQL {
engine.QuoteIdentifier = "\""
session = Session{Engine: engine, Db: db} session = Session{Engine: engine, Db: db}
} else if engine.Protocol == MSSQL {
engine.QuoteIdentifier = ""
session = Session{Engine: engine, Db: db}
} else {
engine.QuoteIdentifier = "`"
session = Session{Engine: engine, Db: db}
}
session.Init() session.Init()
return return
} }
@ -94,6 +59,11 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Engine {
return engine return engine
} }
func (engine *Engine) Id(id int) *Engine {
engine.Statement.Id(id)
return engine
}
func (engine *Engine) Limit(limit int, start ...int) *Engine { func (engine *Engine) Limit(limit int, start ...int) *Engine {
engine.Statement.Limit(limit, start...) engine.Statement.Limit(limit, start...)
return engine return engine
@ -125,12 +95,12 @@ func (e *Engine) genColumnStr(col *Column) string {
if col.SQLType == Date { if col.SQLType == Date {
sql += " datetime " sql += " datetime "
} else { } else {
if e.Protocol == SQLITE && col.IsPrimaryKey { if e.DriverName == SQLITE && col.IsPrimaryKey {
sql += "integer" sql += "integer"
} else { } else {
sql += col.SQLType.Name sql += col.SQLType.Name
} }
if e.Protocol != SQLITE { if e.DriverName != SQLITE {
if col.SQLType != Decimal { if col.SQLType != Decimal {
sql += "(" + strconv.Itoa(col.Length) + ")" sql += "(" + strconv.Itoa(col.Length) + ")"
} else { } else {
@ -165,17 +135,11 @@ func (e *Engine) genCreateSQL(table *Table) string {
sql += "," sql += ","
} }
sql = sql[:len(sql)-2] + ");" sql = sql[:len(sql)-2] + ");"
if e.ShowSQL {
fmt.Println(sql)
}
return sql return sql
} }
func (e *Engine) genDropSQL(table *Table) string { func (e *Engine) genDropSQL(table *Table) string {
sql := "DROP TABLE IF EXISTS `" + table.Name + "`;" sql := "DROP TABLE IF EXISTS `" + table.Name + "`;"
if e.ShowSQL {
fmt.Println(sql)
}
return sql return sql
} }

View File

@ -32,6 +32,11 @@ func (session *Session) Where(querystring string, args ...interface{}) *Session
return session return session
} }
func (session *Session) Id(id int) *Session {
session.Statement.Id(id)
return session
}
func (session *Session) Limit(limit int, start ...int) *Session { func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...) session.Statement.Limit(limit, start...)
return session return session
@ -171,6 +176,9 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result,
} }
func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) {
if session.Statement.Table != nil && session.Statement.Table.PrimaryKey != "" {
sql = strings.Replace(sql, "(id)", session.Statement.Table.PrimaryKey, -1)
}
if session.Engine.ShowSQL { if session.Engine.ShowSQL {
fmt.Println(sql) fmt.Println(sql)
} }
@ -272,11 +280,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
return nil return nil
} }
func (session *Session) Query(sqls string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
if session.Engine.ShowSQL { if session.Statement.Table != nil && session.Statement.Table.PrimaryKey != "" {
fmt.Println(sqls) sql = strings.Replace(sql, "(id)", session.Statement.Table.PrimaryKey, -1)
} }
s, err := session.Db.Prepare(sqls) if session.Engine.ShowSQL {
fmt.Println(sql)
}
s, err := session.Db.Prepare(sql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -354,7 +365,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) {
func (session *Session) InsertOne(bean interface{}) (int64, error) { func (session *Session) InsertOne(bean interface{}) (int64, error) {
table := session.Engine.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
session.Statement.Table = table
colNames := make([]string, 0) colNames := make([]string, 0)
colPlaces := make([]string, 0) colPlaces := make([]string, 0)
var args = make([]interface{}, 0) var args = make([]interface{}, 0)
@ -425,6 +436,7 @@ func (session *Session) BuildConditions(table *Table, bean interface{}) ([]strin
func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) {
table := session.Engine.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
session.Statement.Table = table
colNames, args := session.BuildConditions(table, bean) colNames, args := session.BuildConditions(table, bean)
var condiColNames []string var condiColNames []string
var condiArgs []interface{} var condiArgs []interface{}
@ -473,6 +485,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
func (session *Session) Delete(bean interface{}) (int64, error) { func (session *Session) Delete(bean interface{}) (int64, error) {
table := session.Engine.Bean2Table(bean) table := session.Engine.Bean2Table(bean)
session.Statement.Table = table
colNames, args := session.BuildConditions(table, bean) colNames, args := session.BuildConditions(table, bean)
var condition = "" var condition = ""

View File

@ -38,6 +38,16 @@ func (statement *Statement) Where(querystring string, args ...interface{}) {
statement.Params = args statement.Params = args
} }
func (statement *Statement) Id(id int) {
if statement.WhereStr == "" {
statement.WhereStr = "(id)=?"
statement.Params = []interface{}{id}
} else {
statement.WhereStr = statement.WhereStr + " and (id)=?"
statement.Params = append(statement.Params, id)
}
}
func (statement *Statement) Limit(limit int, start ...int) { func (statement *Statement) Limit(limit int, start ...int) {
statement.LimitN = limit statement.LimitN = limit
if len(start) > 0 { if len(start) > 0 {
@ -76,7 +86,7 @@ func (statement Statement) genCountSql() string {
} }
func (statement Statement) genSelectSql(columnStr string) (a string) { func (statement Statement) genSelectSql(columnStr string) (a string) {
if statement.Engine.Protocol == "mssql" { if statement.Engine.DriverName == MSSQL {
if statement.Start > 0 { if statement.Start > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
statement.Table.PKColumn().Name, statement.Table.PKColumn().Name,

46
xorm.go
View File

@ -2,48 +2,26 @@ package xorm
import ( import (
"reflect" "reflect"
"strings"
) )
// 'sqlite:///foo.db' func Create(driverName string, dataSourceName string) Engine {
// 'sqlite:////Uses/lunny/foo.db' engine := Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{},
// 'sqlite:///:memory:' DataSourceName: dataSourceName}
// '<protocol>://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
func Create(schema string) Engine {
engine := Engine{}
engine.Mapper = SnakeMapper{}
engine.Tables = make(map[reflect.Type]Table) engine.Tables = make(map[reflect.Type]Table)
engine.Statement.Engine = &engine engine.Statement.Engine = &engine
l := strings.Split(schema, "://") if driverName == SQLITE {
if len(l) == 2 {
engine.Protocol = l[0]
if l[0] == "sqlite" {
engine.Charset = "utf8"
engine.AutoIncrement = "AUTOINCREMENT" engine.AutoIncrement = "AUTOINCREMENT"
if l[1] == "/:memory:" {
engine.Others = l[1]
} else if strings.Index(l[1], "//") == 0 {
engine.Others = l[1][1:]
} else if strings.Index(l[1], "/") == 0 {
engine.Others = "." + l[1]
}
} else { } else {
engine.AutoIncrement = "AUTO_INCREMENT" engine.AutoIncrement = "AUTO_INCREMENT"
x := strings.Split(l[1], ":")
engine.UserName = x[0]
y := strings.Split(x[1], "@")
engine.Password = y[0]
z := strings.Split(y[1], "/")
engine.Host = z[0]
a := strings.Split(z[1], "?")
engine.DBName = a[0]
if len(a) == 2 {
engine.Charset = strings.Split(a[1], "=")[1]
} else {
engine.Charset = "utf8"
}
}
} }
if engine.DriverName == PQSQL {
engine.QuoteIdentifier = "\""
} else if engine.DriverName == MSSQL {
engine.QuoteIdentifier = ""
} else {
engine.QuoteIdentifier = "`"
}
return engine return engine
} }

View File

@ -132,8 +132,13 @@ func insertTwoTable(t *testing.T) {
func update(t *testing.T) { func update(t *testing.T) {
// update by id // update by id
user := Userinfo{Username: "xxx"} user := Userinfo{Username: "xxx"}
condiUser := Userinfo{Uid: 1} _, err := engine.Id(1).Update(&user)
_, err := engine.Update(&user, &condiUser) if err != nil {
t.Error(err)
return
}
_, err = engine.Update(&Userinfo{Username: "yyy"}, &user)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -320,7 +325,7 @@ func combineTransaction(t *testing.T) {
} }
func TestMysql(t *testing.T) { func TestMysql(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test") engine = xorm.Create("mysql", "root:123@/test?charset=utf8")
engine.ShowSQL = true engine.ShowSQL = true
directCreateTable(t) directCreateTable(t)
@ -345,7 +350,7 @@ func TestMysql(t *testing.T) {
} }
func TestSqlite(t *testing.T) { func TestSqlite(t *testing.T) {
engine = xorm.Create("sqlite:///test.db") engine = xorm.Create("sqlite3", "./test.db")
engine.ShowSQL = true engine.ShowSQL = true
directCreateTable(t) directCreateTable(t)