init project

This commit is contained in:
Lunny Xiao 2013-05-03 15:26:51 +08:00
commit 4969e8bf94
9 changed files with 1666 additions and 0 deletions

22
.gitignore vendored Normal file
View File

@ -0,0 +1,22 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe

127
README.md Normal file
View File

@ -0,0 +1,127 @@
orm
=====
orm is an ORM for Go. It lets you map Go structs to tables in a database.
Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future.
All in all, it's not entirely ready for advanced use yet, but it's getting there.
Drivers for Go's sql package which support database/sql includes:
Mysql:[github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv)
Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL)
SQLite:[github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
### Installing xorm
go get github.com/lunny/xorm
### Quick Start
1. Create an database engine (for example: mysql)
```go
engine := xorm.Create("mysql://root:123@localhost/test")
```
2. Define your struct
```go
type User struct {
Id int
Name string
Age int `xorm:"-"`
}
```
for Simple Task, just use engine's functions:
begin start, you should create a database and then we create the tables
```go
err := engine.CreateTables(&User{})
```
then, insert an struct to table
```go
id, err := engine.Insert(&User{Name:"lunny"})
```
or you want to update this struct
```go
user := User{Id:1, Name:"xlw"}
rows, err := engine.Update(&user)
```
3. Fetch a single object by user
```go
var user = User{Id:27}
engine.Get(&user)
var user = User{Name:"xlw"}
engine.Get(&user)
```
for deep use, you should create a session, this func will create a connection to db
```go
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return
}
```
1. Fetch a single object by where
```go
var user Userinfo
session.Where("id=?", 27).Get(&user)
var user2 Userinfo
session.Where(3).Get(&user2) // this is shorthand for the version above
var user3 Userinfo
session.Where("name = ?", "john").Get(&user3) // more complex query
var user4 Userinfo
session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex
```
2. Fetch multiple objects
```go
var allusers []Userinfo
err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20
var tenusers []Userinfo
err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0
var everyone []Userinfo
err := session.Find(&everyone)
```
###***About Map Rules***
1. Struct and struct's fields name should be Pascal style, and the table and column's name default is us
for example:
The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname.
If the keyname is 'UserName' will turn into the select colum 'user_name'
2. You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper
another is use field tag, field tag support the below keywords:
[name] column name
pk the field is a primary key
int(11)/varchar(50) column type
autoincr auto incrment
[not ]null if column can be null value
unique unique
- this field is not map as a table column
## LICENSE
BSD License
[http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/)

435
engine.go Normal file
View File

@ -0,0 +1,435 @@
package xorm
import (
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
type SQLType struct {
Name string
DefaultLength int
}
var (
Int = SQLType{"int", 11}
Char = SQLType{"char", 1}
Varchar = SQLType{"varchar", 50}
Date = SQLType{"date", 24}
Decimal = SQLType{"decimal", 26}
Float = SQLType{"float", 31}
Double = SQLType{"double", 31}
)
const (
PQSQL = "pqsql"
MSSQL = "mssql"
SQLITE = "sqlite"
MYSQL = "mysql"
MYMYSQL = "mymysql"
)
type Column struct {
Name string
FieldName string
SQLType SQLType
Length int
Nullable bool
Default string
IsUnique bool
IsPrimaryKey bool
AutoIncrement bool
}
type Table struct {
Name string
Type reflect.Type
Columns map[string]Column
PrimaryKey string
}
func (table *Table) ColumnStr() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, col.Name)
}
return strings.Join(colNames, ", ")
}
func (table *Table) PlaceHolders() string {
colNames := make([]string, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
colNames = append(colNames, "?")
}
return strings.Join(colNames, ", ")
}
func (table *Table) PKColumn() Column {
return table.Columns[table.PrimaryKey]
}
type Engine struct {
Mapper IMapper
Protocol string
UserName string
Password string
Host string
Port int
DBName string
Charset string
Others string
Tables map[string]Table
AutoIncrement string
ShowSQL bool
QuoteIdentifier string
}
func (e *Engine) OpenDB() (db *sql.DB, err error) {
db = nil
err = nil
if e.Protocol == "sqlite" {
// 'sqlite:///foo.db'
db, err = sql.Open("sqlite3", e.Others)
// 'sqlite:///:memory:'
} else if e.Protocol == "mysql" {
// 'mysql://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
connstr := strings.Join([]string{e.UserName, ":",
e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "")
db, err = sql.Open(e.Protocol, connstr)
} else if e.Protocol == "mymysql" {
// DBNAME/USER/PASSWD
connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/")
db, err = sql.Open(e.Protocol, connstr)
// unix:SOCKPATH*DBNAME/USER/PASSWD
// unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD
// tcp:ADDR*DBNAME/USER/PASSWD
// tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD
}
return
}
func (engine *Engine) MakeSession() (session Session, err error) {
db, err := engine.OpenDB()
if err != nil {
return Session{}, err
}
if engine.Protocol == "pgsql" {
engine.QuoteIdentifier = "\""
session = Session{Engine: engine, Db: db, ParamIteration: 1}
} else if engine.Protocol == "mssql" {
engine.QuoteIdentifier = ""
session = Session{Engine: engine, Db: db, ParamIteration: 1}
} else {
engine.QuoteIdentifier = "`"
session = Session{Engine: engine, Db: db, ParamIteration: 1}
}
session.Mapper = engine.Mapper
session.Init()
return
}
func (sqlType SQLType) genSQL(length int) string {
if sqlType == Date {
return " datetime "
}
return sqlType.Name + "(" + strconv.Itoa(length) + ")"
}
func (e *Engine) genCreateSQL(table *Table) string {
sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` ("
//fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey))
for _, col := range table.Columns {
if col.Name != "" {
sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " "
if col.Nullable {
sql += " NULL "
} else {
sql += " NOT NULL "
}
//fmt.Println(key)
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.AutoIncrement {
sql += e.AutoIncrement + " "
}
if col.IsUnique {
sql += "Unique "
}
sql += ","
}
}
sql = sql[:len(sql)-2] + ");"
if e.ShowSQL {
fmt.Println(sql)
}
return sql
}
func (e *Engine) genDropSQL(table *Table) string {
sql := "DROP TABLE IF EXISTS `" + table.Name + "`;"
if e.ShowSQL {
fmt.Println(sql)
}
return sql
}
func Type(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
return reflect.TypeOf(sliceValue.Interface())
}
func Type2SQLType(t reflect.Type) (st SQLType) {
switch k := t.Kind(); k {
case reflect.Int, reflect.Int32, reflect.Int64:
st = Int
case reflect.String:
st = Varchar
case reflect.Struct:
if t == reflect.TypeOf(time.Time{}) {
st = Date
}
default:
st = Varchar
}
return
}
/*
map an object into a table object
*/
func (engine *Engine) MapOne(bean interface{}) Table {
t := Type(bean)
return engine.MapType(t)
}
func (engine *Engine) MapType(t reflect.Type) Table {
table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t}
table.Columns = make(map[string]Column)
var pkCol *Column = nil
var pkstr = ""
for i := 0; i < t.NumField(); i++ {
tag := t.Field(i).Tag
ormTagStr := tag.Get("xorm")
var col Column
fieldType := t.Field(i).Type
if ormTagStr != "" {
col = Column{FieldName: t.Field(i).Name}
ormTagStr = strings.ToLower(ormTagStr)
tags := strings.Split(ormTagStr, " ")
// TODO:
if len(tags) > 0 {
if tags[0] == "-" {
continue
}
for j, key := range tags {
switch k := strings.ToLower(key); k {
case "pk":
col.IsPrimaryKey = true
pkCol = &col
case "null":
col.Nullable = (tags[j-1] != "not")
case "autoincr":
col.AutoIncrement = true
case "default":
col.Default = tags[j+1]
case "int":
col.SQLType = Int
case "not":
default:
col.Name = k
}
}
if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType)
col.Length = col.SQLType.DefaultLength
}
if col.Name == "" {
col.Name = engine.Mapper.Obj2Table(t.Field(i).Name)
}
}
}
if col.Name == "" {
sqlType := Type2SQLType(fieldType)
col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, true, "", false, false, false}
}
table.Columns[col.Name] = col
if strings.ToLower(t.Field(i).Name) == "id" {
pkstr = col.Name
}
}
if pkCol == nil {
if pkstr != "" {
col := table.Columns[pkstr]
col.IsPrimaryKey = true
col.AutoIncrement = true
col.Nullable = false
col.Length = Int.DefaultLength
table.PrimaryKey = col.Name
}
} else {
table.PrimaryKey = pkCol.Name
}
return table
}
func (engine *Engine) Map(beans ...interface{}) (e error) {
for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; !ok {
table := engine.MapOne(bean)
engine.Tables[table.Name] = table
}
}
return
}
func (engine *Engine) UnMap(beans ...interface{}) (e error) {
for _, bean := range beans {
//t := getBeanType(bean)
tableName := engine.Mapper.Obj2Table(StructName(bean))
if _, ok := engine.Tables[tableName]; ok {
delete(engine.Tables, tableName)
}
}
return
}
func (e *Engine) DropAll() error {
session, err := e.MakeSession()
session.Begin()
defer session.Close()
if err != nil {
return err
}
for _, table := range e.Tables {
sql := e.genDropSQL(&table)
_, err = session.Exec(sql)
if err != nil {
session.Rollback()
break
}
}
session.Commit()
return err
}
func (e *Engine) CreateTables(beans ...interface{}) error {
session, err := e.MakeSession()
session.Begin()
defer session.Close()
if err != nil {
return err
}
for _, bean := range beans {
table := e.MapOne(bean)
e.Tables[table.Name] = table
sql := e.genCreateSQL(&table)
_, err = session.Exec(sql)
if err != nil {
session.Rollback()
break
}
}
session.Commit()
return err
}
func (e *Engine) CreateAll() error {
session, err := e.MakeSession()
session.Begin()
defer session.Close()
if err != nil {
return err
}
for _, table := range e.Tables {
sql := e.genCreateSQL(&table)
_, err = session.Exec(sql)
if err != nil {
session.Rollback()
break
}
}
session.Commit()
return err
}
func (engine *Engine) Insert(beans ...interface{}) (int64, error) {
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return -1, err
}
return session.Insert(beans...)
}
func (engine *Engine) Update(bean interface{}) (int64, error) {
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return -1, err
}
return session.Update(bean)
}
func (engine *Engine) Delete(bean interface{}) (int64, error) {
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return -1, err
}
return session.Delete(bean)
}
func (engine *Engine) Get(bean interface{}) error {
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return err
}
return session.Get(bean)
}
func (engine *Engine) Find(beans interface{}) error {
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return err
}
return session.Find(beans)
}
func (engine *Engine) Count(bean interface{}) (int64, error) {
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
return 0, err
}
return session.Count(bean)
}

20
install Executable file
View File

@ -0,0 +1,20 @@
#!/usr/bin/env bash
if [ ! -f install ]; then
echo 'install must be run within its container folder' 1>&2
exit 1
fi
CURDIR=`pwd`
NEWPATH="$GOPATH/src/${PWD##*/}"
if [ ! -d "$NEWPATH" ]; then
ln -s $CURDIR $NEWPATH
fi
gofmt -w $CURDIR
cd $NEWPATH
go install ${PWD##*/}
cd $CURDIR
echo 'finished'

80
mapper.go Normal file
View File

@ -0,0 +1,80 @@
package xorm
import (
//"reflect"
//"strings"
)
type IMapper interface {
Obj2Table(string) string
Table2Obj(string) string
}
type SnakeMapper struct {
}
func snakeCasedName(name string) string {
newstr := make([]rune, 0)
firstTime := true
for _, chr := range name {
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
if firstTime == true {
firstTime = false
} else {
newstr = append(newstr, '_')
}
chr -= ('A' - 'a')
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func Pascal2Sql(s string) (d string) {
d = ""
lastIdx := 0
for i := 0; i < len(s); i++ {
if s[i] >= 'A' && s[i] <= 'Z' {
if lastIdx < i {
d += s[lastIdx+1 : i]
}
if i != 0 {
d += "_"
}
d += string(s[i] + 32)
lastIdx = i
}
}
d += s[lastIdx+1:]
return
}
func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name)
}
func titleCasedName(name string) string {
newstr := make([]rune, 0)
upNextChar := true
for _, chr := range name {
switch {
case upNextChar:
upNextChar = false
chr -= ('a' - 'A')
case chr == '_':
upNextChar = true
continue
}
newstr = append(newstr, chr)
}
return string(newstr)
}
func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name)
}

601
session.go Normal file
View File

@ -0,0 +1,601 @@
package xorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
func getTypeName(obj interface{}) (typestr string) {
typ := reflect.TypeOf(obj)
typestr = typ.String()
lastDotIndex := strings.LastIndex(typestr, ".")
if lastDotIndex != -1 {
typestr = typestr[lastDotIndex+1:]
}
return
}
func StructName(s interface{}) string {
v := reflect.TypeOf(s)
return Type2StructName(v)
}
func Type2StructName(v reflect.Type) string {
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
return v.Name()
}
type Session struct {
Db *sql.DB
Engine *Engine
Statements []Statement
Mapper IMapper
AutoCommit bool
ParamIteration int
CurStatementIdx int
}
func (session *Session) Init() {
session.Statements = make([]Statement, 0)
/*session.Statement.TableName = ""
session.Statement.LimitStr = 0
session.Statement.OffsetStr = 0
session.Statement.WhereStr = ""
session.Statement.ParamStr = make([]interface{}, 0)
session.Statement.OrderStr = ""
session.Statement.JoinStr = ""
session.Statement.GroupByStr = ""
session.Statement.HavingStr = ""*/
session.CurStatementIdx = -1
session.ParamIteration = 1
}
func (session *Session) Close() {
defer session.Db.Close()
}
func (session *Session) CurrentStatement() *Statement {
if session.CurStatementIdx > -1 {
return &session.Statements[session.CurStatementIdx]
}
return nil
}
func (session *Session) AutoStatement() *Statement {
if session.CurStatementIdx == -1 {
session.newStatement()
}
return session.CurrentStatement()
}
//Execute sql
func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) {
rs, err := session.Db.Prepare(finalQueryString)
if err != nil {
return nil, err
}
defer rs.Close()
res, err := rs.Exec(args...)
if err != nil {
return nil, err
}
return res, nil
}
func (session *Session) Where(querystring interface{}, args ...interface{}) *Session {
statement := session.AutoStatement()
switch querystring := querystring.(type) {
case string:
statement.WhereStr = querystring
case int:
if session.Engine.Protocol == "pgsql" {
statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration)
} else {
statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier)
session.ParamIteration++
}
args = append(args, querystring)
}
statement.ParamStr = args
return session
}
func (session *Session) Limit(start int, size ...int) *Session {
session.AutoStatement().LimitStr = start
if len(size) > 0 {
session.CurrentStatement().OffsetStr = size[0]
}
return session
}
func (session *Session) Offset(offset int) *Session {
session.AutoStatement().OffsetStr = offset
return session
}
func (session *Session) OrderBy(order string) *Session {
session.AutoStatement().OrderStr = order
return session
}
//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
func (session *Session) Join(join_operator, tablename, condition string) *Session {
if session.AutoStatement().JoinStr != "" {
session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
} else {
session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition)
}
return session
}
func (session *Session) GroupBy(keys string) *Session {
session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys)
return session
}
func (session *Session) Having(conditions string) *Session {
session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions)
return session
}
func (session *Session) Begin() {
}
func (session *Session) Rollback() {
}
func (session *Session) Commit() {
for _, statement := range session.Statements {
sql := statement.generateSql()
session.Exec(sql)
}
}
func (session *Session) TableName(bean interface{}) string {
return session.Mapper.Obj2Table(StructName(bean))
}
func (session *Session) newStatement() {
state := Statement{}
state.Session = session
session.Statements = append(session.Statements, state)
session.CurStatementIdx = len(session.Statements) - 1
}
func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error {
dataStruct := reflect.Indirect(reflect.ValueOf(obj))
if dataStruct.Kind() != reflect.Struct {
return errors.New("expected a pointer to a struct")
}
tablName := session.TableName(obj)
table := session.Engine.Tables[tablName]
for key, data := range objMap {
structField := dataStruct.FieldByName(table.Columns[key].FieldName)
if !structField.CanSet() {
continue
}
var v interface{}
switch structField.Type().Kind() {
case reflect.Slice:
v = data
case reflect.String:
v = string(data)
case reflect.Bool:
v = string(data) == "1"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
x, err := strconv.Atoi(string(data))
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
v = x
case reflect.Int64:
x, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
v = x
case reflect.Float32, reflect.Float64:
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error())
}
v = x
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
v = x
//Now only support Time type
case reflect.Struct:
if structField.Type().String() != "time.Time" {
return errors.New("unsupported struct type in Scan: " + structField.Type().String())
}
x, err := time.Parse("2006-01-02 15:04:05", string(data))
if err != nil {
x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data))
if err != nil {
return errors.New("unsupported time format: " + string(data))
}
}
v = x
default:
return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String())
}
structField.Set(reflect.ValueOf(v))
}
return nil
}
func (session *Session) Get(output interface{}) error {
statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(output)
table := session.Engine.Tables[tableName]
statement.Table = &table
resultsSlice, err := session.FindMap(statement)
if err != nil {
return err
}
if len(resultsSlice) == 0 {
return nil
} else if len(resultsSlice) == 1 {
results := resultsSlice[0]
err := session.scanMapIntoStruct(output, results)
if err != nil {
return err
}
} else {
return errors.New("More than one record")
}
return nil
}
func (session *Session) Count(bean interface{}) (int64, error) {
statement := session.AutoStatement()
session.Limit(1)
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
statement.Table = &table
resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr)
if err != nil {
return 0, err
}
var total int64 = 0
for _, results := range resultsSlice {
total, err = strconv.ParseInt(string(results["total"]), 10, 64)
break
}
return int64(total), err
}
func (session *Session) Find(rowsSlicePtr interface{}) error {
statement := session.AutoStatement()
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return errors.New("needs a pointer to a slice")
}
sliceElementType := sliceValue.Type().Elem()
tableName := session.Mapper.Obj2Table(Type2StructName(sliceElementType))
table := session.Engine.Tables[tableName]
statement.Table = &table
resultsSlice, err := session.FindMap(statement)
if err != nil {
return err
}
for _, results := range resultsSlice {
newValue := reflect.New(sliceElementType)
err := session.scanMapIntoStruct(newValue.Interface(), results)
if err != nil {
return err
}
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
}
return nil
}
func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSlice []map[string][]byte, err error) {
if session.Engine.ShowSQL {
fmt.Println(sqls)
}
s, err := session.Db.Prepare(sqls)
if err != nil {
return nil, err
}
defer s.Close()
res, err := s.Query(paramStr...)
if err != nil {
return nil, err
}
defer res.Close()
fields, err := res.Columns()
if err != nil {
return nil, err
}
for res.Next() {
result := make(map[string][]byte)
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers = append(scanResultContainers, &scanResultContainer)
}
if err := res.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
continue
}
aa := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
var str string
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
result[key] = []byte(str)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str)
case reflect.Slice:
if aa.Elem().Kind() == reflect.Uint8 {
result[key] = rawValue.Interface().([]byte)
break
}
case reflect.String:
str = vv.String()
result[key] = []byte(str)
//时间类型
case reflect.Struct:
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str)
}
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) {
sqls := statement.generateSql()
return session.SQL2Map(sqls, statement.ParamStr)
}
func (session *Session) Insert(beans ...interface{}) (int64, error) {
var lastId int64 = -1
for _, bean := range beans {
lastId, err := session.InsertOne(bean)
if err != nil {
return lastId, err
}
}
return lastId, nil
}
func (session *Session) InsertOne(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames := make([]string, 0)
colPlaces := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
val := fieldValue.Interface()
if col.AutoIncrement {
if fieldValue.Int() == 0 {
continue
}
}
args = append(args, val)
colNames = append(colNames, col.Name)
colPlaces = append(colPlaces, "?")
}
statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)",
session.Engine.QuoteIdentifier,
tableName,
session.Engine.QuoteIdentifier,
strings.Join(colNames, ", "),
strings.Join(colPlaces, ", "))
if session.Engine.ShowSQL {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
if err != nil {
return -1, err
}
id, err := res.LastInsertId()
if err != nil {
return -1, err
}
return id, nil
}
func (session *Session) Update(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
switch fieldType.Kind() {
case reflect.String:
if fieldValue.String() == "" {
continue
}
case reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time)
if t.IsZero() {
continue
}
}
default:
continue
}
args = append(args, val)
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
var condition = ""
st := session.CurrentStatement()
if st != nil && st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr)
}
if condition == "" {
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName)
if fieldValue.Int() != 0 {
condition = fmt.Sprintf("WHERE %v = ?", table.PKColumn().Name)
args = append(args, fieldValue.Interface())
}
}
statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v",
session.Engine.QuoteIdentifier,
tableName,
session.Engine.QuoteIdentifier,
strings.Join(colNames, ", "),
condition)
if session.Engine.ShowSQL {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
if err != nil {
return -1, err
}
id, err := res.RowsAffected()
if err != nil {
return -1, err
}
return id, nil
}
func (session *Session) Delete(bean interface{}) (int64, error) {
tableName := session.TableName(bean)
table := session.Engine.Tables[tableName]
colNames := make([]string, 0)
var args = make([]interface{}, 0)
for _, col := range table.Columns {
if col.Name == "" {
continue
}
fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName)
fieldType := reflect.TypeOf(fieldValue.Interface())
val := fieldValue.Interface()
switch fieldType.Kind() {
case reflect.String:
if fieldValue.String() == "" {
continue
}
case reflect.Int, reflect.Int32, reflect.Int64:
if fieldValue.Int() == 0 {
continue
}
case reflect.Struct:
if fieldType == reflect.TypeOf(time.Now()) {
t := fieldValue.Interface().(time.Time)
if t.IsZero() {
continue
}
}
default:
continue
}
args = append(args, val)
colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?")
}
var condition = ""
st := session.CurrentStatement()
if st != nil && st.WhereStr != "" {
condition = fmt.Sprintf("WHERE %v", st.WhereStr)
if len(colNames) > 0 {
condition += " and "
condition += strings.Join(colNames, " and ")
}
} else {
condition = "WHERE " + strings.Join(colNames, " and ")
}
statement := fmt.Sprintf("DELETE FROM %v%v%v %v",
session.Engine.QuoteIdentifier,
tableName,
session.Engine.QuoteIdentifier,
condition)
if session.Engine.ShowSQL {
fmt.Println(statement)
}
res, err := session.Exec(statement, args...)
if err != nil {
return -1, err
}
id, err := res.RowsAffected()
if err != nil {
return -1, err
}
return id, nil
}

144
statement.go Normal file
View File

@ -0,0 +1,144 @@
package xorm
import (
"fmt"
)
type Statement struct {
TableName string
Table *Table
Session *Session
LimitStr int
OffsetStr int
WhereStr string
ParamStr []interface{}
OrderStr string
JoinStr string
GroupByStr string
HavingStr string
}
func (statement *Statement) Limit(start int, size ...int) *Statement {
statement.LimitStr = start
if len(size) > 0 {
statement.OffsetStr = size[0]
}
return statement
}
func (statement *Statement) Offset(offset int) *Statement {
statement.OffsetStr = offset
return statement
}
func (statement *Statement) OrderBy(order string) *Statement {
statement.OrderStr = order
return statement
}
func (statement *Statement) Select(colums string) *Statement {
//statement.ColumnStr = colums
return statement
}
func (statement Statement) generateSql() string {
columnStr := statement.Table.ColumnStr()
return statement.genSelectSql(columnStr)
}
func (statement Statement) genCountSql() string {
return statement.genSelectSql("count(*) as total")
}
func (statement Statement) genSelectSql(columnStr string) (a string) {
session := statement.Session
if session.Engine.Protocol == "mssql" {
if statement.OffsetStr > 0 {
a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v",
statement.Table.PKColumn().Name,
columnStr,
statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
}
a = fmt.Sprintf("select %v from (%v) "+
"as a where rownum between %v and %v",
columnStr,
a,
statement.OffsetStr,
statement.LimitStr)
} else if statement.LimitStr > 0 {
a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
}
if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr)
}
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
} else {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name)
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
}
if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr)
}
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
}
} else {
a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name)
if statement.JoinStr != "" {
a = fmt.Sprintf("%v %v", a, statement.JoinStr)
}
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
}
if statement.GroupByStr != "" {
a = fmt.Sprintf("%v %v", a, statement.GroupByStr)
}
if statement.HavingStr != "" {
a = fmt.Sprintf("%v %v", a, statement.HavingStr)
}
if statement.OrderStr != "" {
a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
}
if statement.OffsetStr > 0 {
a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr)
} else if statement.LimitStr > 0 {
a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr)
}
}
return
}
/*func (statement *Statement) genInsertSQL() string {
table = statement.Table
colNames := make([]string, len(table.Columns))
for idx, col := range table.Columns {
if col.Name == "" {
continue
}
colNames[idx] = col.Name
}
return strings.Join(colNames, ", ")
colNames := make([]string, len(table.Columns))
for idx, col := range table.Columns {
if col.Name == "" {
continue
}
colNames[idx] = "?"
}
strings.Join(colNames, ", ")
}*/

47
xorm.go Normal file
View File

@ -0,0 +1,47 @@
package xorm
import (
"strings"
)
// 'sqlite:///foo.db'
// 'sqlite:////Uses/lunny/foo.db'
// 'sqlite:///:memory:'
// '<protocol>://<username>:<passwd>@<host>/<dbname>?charset=<encoding>'
func Create(schema string) Engine {
engine := Engine{}
engine.Mapper = SnakeMapper{}
engine.Tables = make(map[string]Table)
l := strings.Split(schema, "://")
if len(l) == 2 {
engine.Protocol = l[0]
if l[0] == "sqlite" {
engine.Charset = "utf8"
engine.AutoIncrement = "AUTOINCREMENT"
if l[1] == "/:memory:" {
engine.Others = l[1]
} else if strings.Index(l[1], "//") == 0 {
engine.Others = l[1][1:]
} else if strings.Index(l[1], "/") == 0 {
engine.Others = "." + l[1]
}
} else {
engine.AutoIncrement = "AUTO_INCREMENT"
x := strings.Split(l[1], ":")
engine.UserName = x[0]
y := strings.Split(x[1], "@")
engine.Password = y[0]
z := strings.Split(y[1], "/")
engine.Host = z[0]
a := strings.Split(z[1], "?")
engine.DBName = a[0]
if len(a) == 2 {
engine.Charset = strings.Split(a[1], "=")[1]
} else {
engine.Charset = "utf8"
}
}
}
return engine
}

190
xorm_test.go Normal file
View File

@ -0,0 +1,190 @@
package xorm_test
import (
"fmt"
_ "github.com/Go-SQL-Driver/MySQL"
//_ "github.com/ziutek/mymysql/godrv"
//_ "github.com/mattn/go-sqlite3"
"testing"
"time"
"xorm"
)
/*
CREATE TABLE `userinfo` (
`uid` INT(10) NULL AUTO_INCREMENT,
`username` VARCHAR(64) NULL,
`departname` VARCHAR(64) NULL,
`created` DATE NULL,
PRIMARY KEY (`uid`)
);
CREATE TABLE `userdeatail` (
`uid` INT(10) NULL,
`intro` TEXT NULL,
`profile` TEXT NULL,
PRIMARY KEY (`uid`)
);
*/
type Userinfo struct {
Uid int `xorm:"id pk not null autoincr"`
Username string
Departname string
Alias string `xorm:"-"`
Created time.Time
}
var engine xorm.Engine
func TestCreateEngine(t *testing.T) {
engine = xorm.Create("mysql://root:123@localhost/test")
//engine = orm.Create("mymysql://root:123@localhost/test")
//engine = orm.Create("sqlite:///test.db")
engine.ShowSQL = true
}
func TestDirectCreateTable(t *testing.T) {
err := engine.CreateTables(&Userinfo{})
if err != nil {
t.Error(err)
}
}
func TestMapper(t *testing.T) {
err := engine.UnMap(&Userinfo{})
if err != nil {
t.Error(err)
}
err = engine.Map(&Userinfo{})
if err != nil {
t.Error(err)
}
err = engine.DropAll()
if err != nil {
t.Error(err)
}
err = engine.CreateAll()
if err != nil {
t.Error(err)
}
}
func TestInsert(t *testing.T) {
user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()}
_, err := engine.Insert(&user)
if err != nil {
t.Error(err)
}
}
func TestInsertAutoIncr(t *testing.T) {
// auto increment insert
user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()}
_, err := engine.Insert(&user)
if err != nil {
t.Error(err)
}
}
func TestInsertMulti(t *testing.T) {
user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()}
user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()}
_, err := engine.Insert(&user1, &user2)
if err != nil {
t.Error(err)
}
}
func TestUpdate(t *testing.T) {
// update by id
user := Userinfo{Uid: 1, Username: "xxx"}
_, err := engine.Update(&user)
if err != nil {
t.Error(err)
}
}
func TestDelete(t *testing.T) {
user := Userinfo{Uid: 1}
_, err := engine.Delete(&user)
if err != nil {
t.Error(err)
}
}
func TestGet(t *testing.T) {
user := Userinfo{Uid: 2}
err := engine.Get(&user)
if err != nil {
t.Error(err)
}
fmt.Println(user)
}
func TestFind(t *testing.T) {
users := make([]Userinfo, 0)
err := engine.Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestCount(t *testing.T) {
user := Userinfo{}
total, err := engine.Count(&user)
if err != nil {
t.Error(err)
}
fmt.Printf("Total %d records!!!", total)
}
func TestWhere(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Where("id > ?", 2).Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestLimit(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.Limit(2, 1).Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestOrder(t *testing.T) {
users := make([]Userinfo, 0)
session, err := engine.MakeSession()
defer session.Close()
if err != nil {
t.Error(err)
}
err = session.OrderBy("id desc").Find(&users)
if err != nil {
t.Error(err)
}
fmt.Println(users)
}
func TestTransaction(*testing.T) {
}