From 09848afbcf8be287412ce5f2d11e0986d60def48 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 9 May 2013 09:56:58 +0800 Subject: [PATCH] use default connection params --- README.md | 12 ++++++---- README_CN.md | 8 +++++-- engine.go | 66 ++++++++++++---------------------------------------- session.go | 23 ++++++++++++++---- statement.go | 12 +++++++++- xorm.go | 52 ++++++++++++----------------------------- xorm_test.go | 13 +++++++---- 7 files changed, 82 insertions(+), 104 deletions(-) diff --git a/README.md b/README.md index 7c4f932f..0a88ce22 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ # xorm =========== -[中文](README_CN.md) +[中文](./README_CN.md) xorm is an ORM for Go. It lets you map Go structs to tables in a database. -Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. +Right now, it supports Mysql and SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. All in all, it's not entirely ready for product use yet, but it's getting there. @@ -21,9 +21,13 @@ SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) ## Quick Start -1.Create a database engine (for example: mysql) +1.Create a database engine just like sql.OpenDB (for example: mysql) - engine := xorm.Create("mysql://root:123@localhost/test") + engine := xorm.Create("mysql", "root:123@/test?charset=utf8") + +or + + engine = xorm.Create("sqlite3", "./test.db") 2.Define your struct diff --git a/README_CN.md b/README_CN.md index 7fc5a164..dcdaa212 100644 --- a/README_CN.md +++ b/README_CN.md @@ -21,9 +21,13 @@ SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) ## 快速开始 -1.创建数据库引擎 (例如: mysql) +1.创建数据库引擎,这个函数的参数和sql.OpenDB相同,但不会立即创建连接 (例如: mysql) - engine := xorm.Create("mysql://root:123@localhost/test") + engine := xorm.Create("mysql", "root:123@/test?charset=utf8") + +or + + engine = xorm.Create("sqlite3", "./test.db") 2.定义你的Struct diff --git a/engine.go b/engine.go index 84cff18a..8b60ba9b 100644 --- a/engine.go +++ b/engine.go @@ -2,7 +2,7 @@ package xorm import ( "database/sql" - "fmt" + //"fmt" "reflect" "strconv" "strings" @@ -11,21 +11,15 @@ import ( const ( PQSQL = "pqsql" MSSQL = "mssql" - SQLITE = "sqlite" + SQLITE = "sqlite3" MYSQL = "mysql" MYMYSQL = "mymysql" ) type Engine struct { Mapper IMapper - Protocol string - UserName string - Password string - Host string - Port int - DBName string - Charset string - Others string + DriverName string + DataSourceName string Tables map[reflect.Type]Table AutoIncrement string ShowSQL bool @@ -45,29 +39,8 @@ func StructName(v reflect.Type) string { return v.Name() } -func (e *Engine) OpenDB() (db *sql.DB, err error) { - db = nil - err = nil - if e.Protocol == SQLITE { - // 'sqlite:///foo.db' - db, err = sql.Open("sqlite3", e.Others) - // 'sqlite:///:memory:' - } else if e.Protocol == MYSQL { - // 'mysql://:@/?charset=' - connstr := strings.Join([]string{e.UserName, ":", - e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "") - db, err = sql.Open(e.Protocol, connstr) - } else if e.Protocol == MYMYSQL { - // DBNAME/USER/PASSWD - connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/") - db, err = sql.Open(e.Protocol, connstr) - // unix:SOCKPATH*DBNAME/USER/PASSWD - // unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD - // tcp:ADDR*DBNAME/USER/PASSWD - // tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD - } - - return +func (e *Engine) OpenDB() (*sql.DB, error) { + return sql.Open(e.DriverName, e.DataSourceName) } func (engine *Engine) MakeSession() (session Session, err error) { @@ -75,16 +48,8 @@ func (engine *Engine) MakeSession() (session Session, err error) { if err != nil { return Session{}, err } - if engine.Protocol == PQSQL { - engine.QuoteIdentifier = "\"" - session = Session{Engine: engine, Db: db} - } else if engine.Protocol == MSSQL { - engine.QuoteIdentifier = "" - session = Session{Engine: engine, Db: db} - } else { - engine.QuoteIdentifier = "`" - session = Session{Engine: engine, Db: db} - } + + session = Session{Engine: engine, Db: db} session.Init() return } @@ -94,6 +59,11 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Engine { return engine } +func (engine *Engine) Id(id int) *Engine { + engine.Statement.Id(id) + return engine +} + func (engine *Engine) Limit(limit int, start ...int) *Engine { engine.Statement.Limit(limit, start...) return engine @@ -125,12 +95,12 @@ func (e *Engine) genColumnStr(col *Column) string { if col.SQLType == Date { sql += " datetime " } else { - if e.Protocol == SQLITE && col.IsPrimaryKey { + if e.DriverName == SQLITE && col.IsPrimaryKey { sql += "integer" } else { sql += col.SQLType.Name } - if e.Protocol != SQLITE { + if e.DriverName != SQLITE { if col.SQLType != Decimal { sql += "(" + strconv.Itoa(col.Length) + ")" } else { @@ -165,17 +135,11 @@ func (e *Engine) genCreateSQL(table *Table) string { sql += "," } sql = sql[:len(sql)-2] + ");" - if e.ShowSQL { - fmt.Println(sql) - } return sql } func (e *Engine) genDropSQL(table *Table) string { sql := "DROP TABLE IF EXISTS `" + table.Name + "`;" - if e.ShowSQL { - fmt.Println(sql) - } return sql } diff --git a/session.go b/session.go index 11852f8d..2cc6b683 100644 --- a/session.go +++ b/session.go @@ -32,6 +32,11 @@ func (session *Session) Where(querystring string, args ...interface{}) *Session return session } +func (session *Session) Id(id int) *Session { + session.Statement.Id(id) + return session +} + func (session *Session) Limit(limit int, start ...int) *Session { session.Statement.Limit(limit, start...) return session @@ -171,6 +176,9 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, } func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { + if session.Statement.Table != nil && session.Statement.Table.PrimaryKey != "" { + sql = strings.Replace(sql, "(id)", session.Statement.Table.PrimaryKey, -1) + } if session.Engine.ShowSQL { fmt.Println(sql) } @@ -272,11 +280,14 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } -func (session *Session) Query(sqls string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - if session.Engine.ShowSQL { - fmt.Println(sqls) +func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { + if session.Statement.Table != nil && session.Statement.Table.PrimaryKey != "" { + sql = strings.Replace(sql, "(id)", session.Statement.Table.PrimaryKey, -1) } - s, err := session.Db.Prepare(sqls) + if session.Engine.ShowSQL { + fmt.Println(sql) + } + s, err := session.Db.Prepare(sql) if err != nil { return nil, err } @@ -354,7 +365,7 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { func (session *Session) InsertOne(bean interface{}) (int64, error) { table := session.Engine.Bean2Table(bean) - + session.Statement.Table = table colNames := make([]string, 0) colPlaces := make([]string, 0) var args = make([]interface{}, 0) @@ -425,6 +436,7 @@ func (session *Session) BuildConditions(table *Table, bean interface{}) ([]strin func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { table := session.Engine.Bean2Table(bean) + session.Statement.Table = table colNames, args := session.BuildConditions(table, bean) var condiColNames []string var condiArgs []interface{} @@ -473,6 +485,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.Bean2Table(bean) + session.Statement.Table = table colNames, args := session.BuildConditions(table, bean) var condition = "" diff --git a/statement.go b/statement.go index a771cc25..444edf89 100644 --- a/statement.go +++ b/statement.go @@ -38,6 +38,16 @@ func (statement *Statement) Where(querystring string, args ...interface{}) { statement.Params = args } +func (statement *Statement) Id(id int) { + if statement.WhereStr == "" { + statement.WhereStr = "(id)=?" + statement.Params = []interface{}{id} + } else { + statement.WhereStr = statement.WhereStr + " and (id)=?" + statement.Params = append(statement.Params, id) + } +} + func (statement *Statement) Limit(limit int, start ...int) { statement.LimitN = limit if len(start) > 0 { @@ -76,7 +86,7 @@ func (statement Statement) genCountSql() string { } func (statement Statement) genSelectSql(columnStr string) (a string) { - if statement.Engine.Protocol == "mssql" { + if statement.Engine.DriverName == MSSQL { if statement.Start > 0 { a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", statement.Table.PKColumn().Name, diff --git a/xorm.go b/xorm.go index 1e913a45..d37b8fd0 100644 --- a/xorm.go +++ b/xorm.go @@ -2,48 +2,26 @@ package xorm import ( "reflect" - "strings" ) -// 'sqlite:///foo.db' -// 'sqlite:////Uses/lunny/foo.db' -// 'sqlite:///:memory:' -// '://:@/?charset=' -func Create(schema string) Engine { - engine := Engine{} - engine.Mapper = SnakeMapper{} +func Create(driverName string, dataSourceName string) Engine { + engine := Engine{ShowSQL: false, DriverName: driverName, Mapper: SnakeMapper{}, + DataSourceName: dataSourceName} + engine.Tables = make(map[reflect.Type]Table) engine.Statement.Engine = &engine - l := strings.Split(schema, "://") - if len(l) == 2 { - engine.Protocol = l[0] - if l[0] == "sqlite" { - engine.Charset = "utf8" - engine.AutoIncrement = "AUTOINCREMENT" - if l[1] == "/:memory:" { - engine.Others = l[1] - } else if strings.Index(l[1], "//") == 0 { - engine.Others = l[1][1:] - } else if strings.Index(l[1], "/") == 0 { - engine.Others = "." + l[1] - } - } else { - engine.AutoIncrement = "AUTO_INCREMENT" - x := strings.Split(l[1], ":") - engine.UserName = x[0] - y := strings.Split(x[1], "@") - engine.Password = y[0] - z := strings.Split(y[1], "/") - engine.Host = z[0] - a := strings.Split(z[1], "?") - engine.DBName = a[0] - if len(a) == 2 { - engine.Charset = strings.Split(a[1], "=")[1] - } else { - engine.Charset = "utf8" - } - } + if driverName == SQLITE { + engine.AutoIncrement = "AUTOINCREMENT" + } else { + engine.AutoIncrement = "AUTO_INCREMENT" } + if engine.DriverName == PQSQL { + engine.QuoteIdentifier = "\"" + } else if engine.DriverName == MSSQL { + engine.QuoteIdentifier = "" + } else { + engine.QuoteIdentifier = "`" + } return engine } diff --git a/xorm_test.go b/xorm_test.go index 0a3b9131..34af5d6e 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -132,8 +132,13 @@ func insertTwoTable(t *testing.T) { func update(t *testing.T) { // update by id user := Userinfo{Username: "xxx"} - condiUser := Userinfo{Uid: 1} - _, err := engine.Update(&user, &condiUser) + _, err := engine.Id(1).Update(&user) + if err != nil { + t.Error(err) + return + } + + _, err = engine.Update(&Userinfo{Username: "yyy"}, &user) if err != nil { t.Error(err) } @@ -320,7 +325,7 @@ func combineTransaction(t *testing.T) { } func TestMysql(t *testing.T) { - engine = xorm.Create("mysql://root:123@localhost/test") + engine = xorm.Create("mysql", "root:123@/test?charset=utf8") engine.ShowSQL = true directCreateTable(t) @@ -345,7 +350,7 @@ func TestMysql(t *testing.T) { } func TestSqlite(t *testing.T) { - engine = xorm.Create("sqlite:///test.db") + engine = xorm.Create("sqlite3", "./test.db") engine.ShowSQL = true directCreateTable(t)