From 4969e8bf94d4df8b9351c3f62f8724234e8d3b3a Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 3 May 2013 15:26:51 +0800 Subject: [PATCH] init project --- .gitignore | 22 ++ README.md | 127 +++++++++++ engine.go | 435 +++++++++++++++++++++++++++++++++++++ install | 20 ++ mapper.go | 80 +++++++ session.go | 601 +++++++++++++++++++++++++++++++++++++++++++++++++++ statement.go | 144 ++++++++++++ xorm.go | 47 ++++ xorm_test.go | 190 ++++++++++++++++ 9 files changed, 1666 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 engine.go create mode 100755 install create mode 100644 mapper.go create mode 100644 session.go create mode 100644 statement.go create mode 100644 xorm.go create mode 100644 xorm_test.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..00268614 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/README.md b/README.md new file mode 100644 index 00000000..256aa34d --- /dev/null +++ b/README.md @@ -0,0 +1,127 @@ +orm +===== + +orm is an ORM for Go. It lets you map Go structs to tables in a database. + +Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. + +All in all, it's not entirely ready for advanced use yet, but it's getting there. + +Drivers for Go's sql package which support database/sql includes: + +Mysql:[github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) + +Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) + +SQLite:[github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + +### Installing xorm + go get github.com/lunny/xorm + +### Quick Start + +1. Create an database engine (for example: mysql) + +```go +engine := xorm.Create("mysql://root:123@localhost/test") +``` + +2. Define your struct + +```go +type User struct { + Id int + Name string + Age int `xorm:"-"` +} +``` + +for Simple Task, just use engine's functions: + +begin start, you should create a database and then we create the tables + +```go +err := engine.CreateTables(&User{}) +``` + +then, insert an struct to table + +```go +id, err := engine.Insert(&User{Name:"lunny"}) +``` + +or you want to update this struct + +```go +user := User{Id:1, Name:"xlw"} +rows, err := engine.Update(&user) +``` + +3. Fetch a single object by user +```go +var user = User{Id:27} +engine.Get(&user) + +var user = User{Name:"xlw"} +engine.Get(&user) +``` + +for deep use, you should create a session, this func will create a connection to db + +```go +session, err := engine.MakeSession() +defer session.Close() +if err != nil { + return +} +``` + +1. Fetch a single object by where +```go +var user Userinfo +session.Where("id=?", 27).Get(&user) + +var user2 Userinfo +session.Where(3).Get(&user2) // this is shorthand for the version above + +var user3 Userinfo +session.Where("name = ?", "john").Get(&user3) // more complex query + +var user4 Userinfo +session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex +``` + +2. Fetch multiple objects + +```go +var allusers []Userinfo +err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 + +var tenusers []Userinfo +err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0 + +var everyone []Userinfo +err := session.Find(&everyone) +``` + +###***About Map Rules*** +1. Struct and struct's fields name should be Pascal style, and the table and column's name default is us +for example: +The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname. +If the keyname is 'UserName' will turn into the select colum 'user_name' + +2. You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper + +another is use field tag, field tag support the below keywords: +[name] column name +pk the field is a primary key +int(11)/varchar(50) column type +autoincr auto incrment +[not ]null if column can be null value +unique unique +- this field is not map as a table column + +## LICENSE + + BSD License + [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) diff --git a/engine.go b/engine.go new file mode 100644 index 00000000..75575a74 --- /dev/null +++ b/engine.go @@ -0,0 +1,435 @@ +package xorm + +import ( + "database/sql" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +type SQLType struct { + Name string + DefaultLength int +} + +var ( + Int = SQLType{"int", 11} + Char = SQLType{"char", 1} + Varchar = SQLType{"varchar", 50} + Date = SQLType{"date", 24} + Decimal = SQLType{"decimal", 26} + Float = SQLType{"float", 31} + Double = SQLType{"double", 31} +) + +const ( + PQSQL = "pqsql" + MSSQL = "mssql" + SQLITE = "sqlite" + MYSQL = "mysql" + MYMYSQL = "mymysql" +) + +type Column struct { + Name string + FieldName string + SQLType SQLType + Length int + Nullable bool + Default string + IsUnique bool + IsPrimaryKey bool + AutoIncrement bool +} + +type Table struct { + Name string + Type reflect.Type + Columns map[string]Column + PrimaryKey string +} + +func (table *Table) ColumnStr() string { + colNames := make([]string, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + colNames = append(colNames, col.Name) + } + return strings.Join(colNames, ", ") +} + +func (table *Table) PlaceHolders() string { + colNames := make([]string, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + colNames = append(colNames, "?") + } + return strings.Join(colNames, ", ") +} + +func (table *Table) PKColumn() Column { + return table.Columns[table.PrimaryKey] +} + +type Engine struct { + Mapper IMapper + Protocol string + UserName string + Password string + Host string + Port int + DBName string + Charset string + Others string + Tables map[string]Table + AutoIncrement string + ShowSQL bool + QuoteIdentifier string +} + +func (e *Engine) OpenDB() (db *sql.DB, err error) { + db = nil + err = nil + if e.Protocol == "sqlite" { + // 'sqlite:///foo.db' + db, err = sql.Open("sqlite3", e.Others) + // 'sqlite:///:memory:' + } else if e.Protocol == "mysql" { + // 'mysql://:@/?charset=' + connstr := strings.Join([]string{e.UserName, ":", + e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "") + db, err = sql.Open(e.Protocol, connstr) + } else if e.Protocol == "mymysql" { + // DBNAME/USER/PASSWD + connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/") + db, err = sql.Open(e.Protocol, connstr) + // unix:SOCKPATH*DBNAME/USER/PASSWD + // unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD + // tcp:ADDR*DBNAME/USER/PASSWD + // tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD + } + + return +} + +func (engine *Engine) MakeSession() (session Session, err error) { + db, err := engine.OpenDB() + if err != nil { + return Session{}, err + } + if engine.Protocol == "pgsql" { + engine.QuoteIdentifier = "\"" + session = Session{Engine: engine, Db: db, ParamIteration: 1} + } else if engine.Protocol == "mssql" { + engine.QuoteIdentifier = "" + session = Session{Engine: engine, Db: db, ParamIteration: 1} + } else { + engine.QuoteIdentifier = "`" + session = Session{Engine: engine, Db: db, ParamIteration: 1} + } + session.Mapper = engine.Mapper + session.Init() + return +} + +func (sqlType SQLType) genSQL(length int) string { + if sqlType == Date { + return " datetime " + } + return sqlType.Name + "(" + strconv.Itoa(length) + ")" +} + +func (e *Engine) genCreateSQL(table *Table) string { + sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` (" + //fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey)) + for _, col := range table.Columns { + if col.Name != "" { + sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " " + if col.Nullable { + sql += " NULL " + } else { + sql += " NOT NULL " + } + //fmt.Println(key) + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + } + if col.AutoIncrement { + sql += e.AutoIncrement + " " + } + if col.IsUnique { + sql += "Unique " + } + sql += "," + } + } + sql = sql[:len(sql)-2] + ");" + if e.ShowSQL { + fmt.Println(sql) + } + return sql +} + +func (e *Engine) genDropSQL(table *Table) string { + sql := "DROP TABLE IF EXISTS `" + table.Name + "`;" + if e.ShowSQL { + fmt.Println(sql) + } + return sql +} + +func Type(bean interface{}) reflect.Type { + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + return reflect.TypeOf(sliceValue.Interface()) +} + +func Type2SQLType(t reflect.Type) (st SQLType) { + switch k := t.Kind(); k { + case reflect.Int, reflect.Int32, reflect.Int64: + st = Int + case reflect.String: + st = Varchar + case reflect.Struct: + if t == reflect.TypeOf(time.Time{}) { + st = Date + } + default: + st = Varchar + } + return +} + +/* +map an object into a table object +*/ +func (engine *Engine) MapOne(bean interface{}) Table { + t := Type(bean) + return engine.MapType(t) +} + +func (engine *Engine) MapType(t reflect.Type) Table { + table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} + table.Columns = make(map[string]Column) + var pkCol *Column = nil + var pkstr = "" + + for i := 0; i < t.NumField(); i++ { + tag := t.Field(i).Tag + ormTagStr := tag.Get("xorm") + var col Column + fieldType := t.Field(i).Type + + if ormTagStr != "" { + col = Column{FieldName: t.Field(i).Name} + ormTagStr = strings.ToLower(ormTagStr) + tags := strings.Split(ormTagStr, " ") + // TODO: + if len(tags) > 0 { + if tags[0] == "-" { + continue + } + for j, key := range tags { + switch k := strings.ToLower(key); k { + case "pk": + col.IsPrimaryKey = true + pkCol = &col + case "null": + col.Nullable = (tags[j-1] != "not") + case "autoincr": + col.AutoIncrement = true + case "default": + col.Default = tags[j+1] + case "int": + col.SQLType = Int + case "not": + default: + col.Name = k + } + } + if col.SQLType.Name == "" { + col.SQLType = Type2SQLType(fieldType) + col.Length = col.SQLType.DefaultLength + } + + if col.Name == "" { + col.Name = engine.Mapper.Obj2Table(t.Field(i).Name) + } + } + } + + if col.Name == "" { + sqlType := Type2SQLType(fieldType) + col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, + sqlType.DefaultLength, true, "", false, false, false} + } + table.Columns[col.Name] = col + if strings.ToLower(t.Field(i).Name) == "id" { + pkstr = col.Name + } + } + + if pkCol == nil { + if pkstr != "" { + col := table.Columns[pkstr] + col.IsPrimaryKey = true + col.AutoIncrement = true + col.Nullable = false + col.Length = Int.DefaultLength + table.PrimaryKey = col.Name + } + } else { + table.PrimaryKey = pkCol.Name + } + + return table +} + +func (engine *Engine) Map(beans ...interface{}) (e error) { + for _, bean := range beans { + //t := getBeanType(bean) + tableName := engine.Mapper.Obj2Table(StructName(bean)) + if _, ok := engine.Tables[tableName]; !ok { + table := engine.MapOne(bean) + engine.Tables[table.Name] = table + } + } + return +} + +func (engine *Engine) UnMap(beans ...interface{}) (e error) { + for _, bean := range beans { + //t := getBeanType(bean) + tableName := engine.Mapper.Obj2Table(StructName(bean)) + if _, ok := engine.Tables[tableName]; ok { + delete(engine.Tables, tableName) + } + } + return +} + +func (e *Engine) DropAll() error { + session, err := e.MakeSession() + session.Begin() + defer session.Close() + if err != nil { + return err + } + + for _, table := range e.Tables { + sql := e.genDropSQL(&table) + _, err = session.Exec(sql) + if err != nil { + session.Rollback() + break + } + } + session.Commit() + return err +} + +func (e *Engine) CreateTables(beans ...interface{}) error { + session, err := e.MakeSession() + session.Begin() + defer session.Close() + if err != nil { + return err + } + for _, bean := range beans { + table := e.MapOne(bean) + e.Tables[table.Name] = table + sql := e.genCreateSQL(&table) + _, err = session.Exec(sql) + if err != nil { + session.Rollback() + break + } + } + session.Commit() + return err +} + +func (e *Engine) CreateAll() error { + session, err := e.MakeSession() + session.Begin() + defer session.Close() + if err != nil { + return err + } + + for _, table := range e.Tables { + sql := e.genCreateSQL(&table) + _, err = session.Exec(sql) + if err != nil { + session.Rollback() + break + } + } + session.Commit() + return err +} + +func (engine *Engine) Insert(beans ...interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return -1, err + } + + return session.Insert(beans...) +} + +func (engine *Engine) Update(bean interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return -1, err + } + + return session.Update(bean) +} + +func (engine *Engine) Delete(bean interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return -1, err + } + + return session.Delete(bean) +} + +func (engine *Engine) Get(bean interface{}) error { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return err + } + + return session.Get(bean) +} + +func (engine *Engine) Find(beans interface{}) error { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return err + } + + return session.Find(beans) +} + +func (engine *Engine) Count(bean interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return 0, err + } + + return session.Count(bean) +} diff --git a/install b/install new file mode 100755 index 00000000..1ccc8fc6 --- /dev/null +++ b/install @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +if [ ! -f install ]; then +echo 'install must be run within its container folder' 1>&2 +exit 1 +fi + +CURDIR=`pwd` +NEWPATH="$GOPATH/src/${PWD##*/}" +if [ ! -d "$NEWPATH" ]; then +ln -s $CURDIR $NEWPATH +fi + +gofmt -w $CURDIR + +cd $NEWPATH +go install ${PWD##*/} +cd $CURDIR + +echo 'finished' diff --git a/mapper.go b/mapper.go new file mode 100644 index 00000000..a00682b5 --- /dev/null +++ b/mapper.go @@ -0,0 +1,80 @@ +package xorm + +import ( +//"reflect" +//"strings" +) + +type IMapper interface { + Obj2Table(string) string + Table2Obj(string) string +} + +type SnakeMapper struct { +} + +func snakeCasedName(name string) string { + newstr := make([]rune, 0) + firstTime := true + + for _, chr := range name { + if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { + if firstTime == true { + firstTime = false + } else { + newstr = append(newstr, '_') + } + chr -= ('A' - 'a') + } + newstr = append(newstr, chr) + } + + return string(newstr) +} + +func Pascal2Sql(s string) (d string) { + d = "" + lastIdx := 0 + for i := 0; i < len(s); i++ { + if s[i] >= 'A' && s[i] <= 'Z' { + if lastIdx < i { + d += s[lastIdx+1 : i] + } + if i != 0 { + d += "_" + } + d += string(s[i] + 32) + lastIdx = i + } + } + d += s[lastIdx+1:] + return +} + +func (mapper SnakeMapper) Obj2Table(name string) string { + return snakeCasedName(name) +} + +func titleCasedName(name string) string { + newstr := make([]rune, 0) + upNextChar := true + + for _, chr := range name { + switch { + case upNextChar: + upNextChar = false + chr -= ('a' - 'A') + case chr == '_': + upNextChar = true + continue + } + + newstr = append(newstr, chr) + } + + return string(newstr) +} + +func (mapper SnakeMapper) Table2Obj(name string) string { + return titleCasedName(name) +} diff --git a/session.go b/session.go new file mode 100644 index 00000000..784f2d86 --- /dev/null +++ b/session.go @@ -0,0 +1,601 @@ +package xorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +func getTypeName(obj interface{}) (typestr string) { + typ := reflect.TypeOf(obj) + typestr = typ.String() + + lastDotIndex := strings.LastIndex(typestr, ".") + if lastDotIndex != -1 { + typestr = typestr[lastDotIndex+1:] + } + + return +} + +func StructName(s interface{}) string { + v := reflect.TypeOf(s) + return Type2StructName(v) +} + +func Type2StructName(v reflect.Type) string { + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Name() +} + +type Session struct { + Db *sql.DB + Engine *Engine + Statements []Statement + Mapper IMapper + AutoCommit bool + ParamIteration int + CurStatementIdx int +} + +func (session *Session) Init() { + session.Statements = make([]Statement, 0) + /*session.Statement.TableName = "" + session.Statement.LimitStr = 0 + session.Statement.OffsetStr = 0 + session.Statement.WhereStr = "" + session.Statement.ParamStr = make([]interface{}, 0) + session.Statement.OrderStr = "" + session.Statement.JoinStr = "" + session.Statement.GroupByStr = "" + session.Statement.HavingStr = ""*/ + session.CurStatementIdx = -1 + + session.ParamIteration = 1 +} + +func (session *Session) Close() { + defer session.Db.Close() +} + +func (session *Session) CurrentStatement() *Statement { + if session.CurStatementIdx > -1 { + return &session.Statements[session.CurStatementIdx] + } + return nil +} + +func (session *Session) AutoStatement() *Statement { + if session.CurStatementIdx == -1 { + session.newStatement() + } + return session.CurrentStatement() +} + +//Execute sql +func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(finalQueryString) + if err != nil { + return nil, err + } + defer rs.Close() + + res, err := rs.Exec(args...) + if err != nil { + return nil, err + } + return res, nil +} + +func (session *Session) Where(querystring interface{}, args ...interface{}) *Session { + statement := session.AutoStatement() + switch querystring := querystring.(type) { + case string: + statement.WhereStr = querystring + case int: + if session.Engine.Protocol == "pgsql" { + statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration) + } else { + statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier) + session.ParamIteration++ + } + args = append(args, querystring) + } + statement.ParamStr = args + return session +} + +func (session *Session) Limit(start int, size ...int) *Session { + session.AutoStatement().LimitStr = start + if len(size) > 0 { + session.CurrentStatement().OffsetStr = size[0] + } + return session +} + +func (session *Session) Offset(offset int) *Session { + session.AutoStatement().OffsetStr = offset + return session +} + +func (session *Session) OrderBy(order string) *Session { + session.AutoStatement().OrderStr = order + return session +} + +//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (session *Session) Join(join_operator, tablename, condition string) *Session { + if session.AutoStatement().JoinStr != "" { + session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } else { + session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } + + return session +} + +func (session *Session) GroupBy(keys string) *Session { + session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys) + return session +} + +func (session *Session) Having(conditions string) *Session { + session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions) + return session +} + +func (session *Session) Begin() { + +} + +func (session *Session) Rollback() { + +} + +func (session *Session) Commit() { + for _, statement := range session.Statements { + sql := statement.generateSql() + session.Exec(sql) + } +} + +func (session *Session) TableName(bean interface{}) string { + return session.Mapper.Obj2Table(StructName(bean)) +} + +func (session *Session) newStatement() { + state := Statement{} + state.Session = session + session.Statements = append(session.Statements, state) + session.CurStatementIdx = len(session.Statements) - 1 +} + +func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { + dataStruct := reflect.Indirect(reflect.ValueOf(obj)) + if dataStruct.Kind() != reflect.Struct { + return errors.New("expected a pointer to a struct") + } + + tablName := session.TableName(obj) + table := session.Engine.Tables[tablName] + + for key, data := range objMap { + structField := dataStruct.FieldByName(table.Columns[key].FieldName) + if !structField.CanSet() { + continue + } + + var v interface{} + + switch structField.Type().Kind() { + case reflect.Slice: + v = data + case reflect.String: + v = string(data) + case reflect.Bool: + v = string(data) == "1" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + x, err := strconv.Atoi(string(data)) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + v = x + case reflect.Int64: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + v = x + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + v = x + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + v = x + //Now only support Time type + case reflect.Struct: + if structField.Type().String() != "time.Time" { + return errors.New("unsupported struct type in Scan: " + structField.Type().String()) + } + + x, err := time.Parse("2006-01-02 15:04:05", string(data)) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data)) + + if err != nil { + return errors.New("unsupported time format: " + string(data)) + } + } + + v = x + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + } + + structField.Set(reflect.ValueOf(v)) + } + + return nil +} + +func (session *Session) Get(output interface{}) error { + statement := session.AutoStatement() + session.Limit(1) + tableName := session.TableName(output) + table := session.Engine.Tables[tableName] + statement.Table = &table + + resultsSlice, err := session.FindMap(statement) + if err != nil { + return err + } + if len(resultsSlice) == 0 { + return nil + } else if len(resultsSlice) == 1 { + results := resultsSlice[0] + err := session.scanMapIntoStruct(output, results) + if err != nil { + return err + } + } else { + return errors.New("More than one record") + } + return nil +} + +func (session *Session) Count(bean interface{}) (int64, error) { + statement := session.AutoStatement() + session.Limit(1) + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + statement.Table = &table + + resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr) + if err != nil { + return 0, err + } + + var total int64 = 0 + for _, results := range resultsSlice { + total, err = strconv.ParseInt(string(results["total"]), 10, 64) + break + } + + return int64(total), err +} + +func (session *Session) Find(rowsSlicePtr interface{}) error { + statement := session.AutoStatement() + + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice { + return errors.New("needs a pointer to a slice") + } + + sliceElementType := sliceValue.Type().Elem() + + tableName := session.Mapper.Obj2Table(Type2StructName(sliceElementType)) + table := session.Engine.Tables[tableName] + statement.Table = &table + + resultsSlice, err := session.FindMap(statement) + if err != nil { + return err + } + + for _, results := range resultsSlice { + newValue := reflect.New(sliceElementType) + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err + } + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + return nil +} + +func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSlice []map[string][]byte, err error) { + if session.Engine.ShowSQL { + fmt.Println(sqls) + } + s, err := session.Db.Prepare(sqls) + if err != nil { + return nil, err + } + defer s.Close() + res, err := s.Query(paramStr...) + if err != nil { + return nil, err + } + defer res.Close() + fields, err := res.Columns() + if err != nil { + return nil, err + } + for res.Next() { + result := make(map[string][]byte) + var scanResultContainers []interface{} + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := res.Scan(scanResultContainers...); err != nil { + return nil, err + } + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + continue + } + aa := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + var str string + switch aa.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + + result[key] = []byte(str) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + result[key] = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + result[key] = []byte(str) + case reflect.Slice: + if aa.Elem().Kind() == reflect.Uint8 { + result[key] = rawValue.Interface().([]byte) + break + } + case reflect.String: + str = vv.String() + result[key] = []byte(str) + //时间类型 + case reflect.Struct: + str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") + result[key] = []byte(str) + } + + } + resultsSlice = append(resultsSlice, result) + } + return resultsSlice, nil +} + +func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) { + sqls := statement.generateSql() + return session.SQL2Map(sqls, statement.ParamStr) +} + +func (session *Session) Insert(beans ...interface{}) (int64, error) { + var lastId int64 = -1 + for _, bean := range beans { + lastId, err := session.InsertOne(bean) + if err != nil { + return lastId, err + } + } + return lastId, nil +} + +func (session *Session) InsertOne(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + + colNames := make([]string, 0) + colPlaces := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + val := fieldValue.Interface() + if col.AutoIncrement { + if fieldValue.Int() == 0 { + continue + } + } + args = append(args, val) + colNames = append(colNames, col.Name) + colPlaces = append(colPlaces, "?") + } + + statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", + session.Engine.QuoteIdentifier, + tableName, + session.Engine.QuoteIdentifier, + strings.Join(colNames, ", "), + strings.Join(colPlaces, ", ")) + + if session.Engine.ShowSQL { + fmt.Println(statement) + } + + res, err := session.Exec(statement, args...) + if err != nil { + return -1, err + } + + id, err := res.LastInsertId() + + if err != nil { + return -1, err + } + return id, nil +} + +func (session *Session) Update(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + fieldType := reflect.TypeOf(fieldValue.Interface()) + val := fieldValue.Interface() + switch fieldType.Kind() { + case reflect.String: + if fieldValue.String() == "" { + continue + } + case reflect.Int, reflect.Int32, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if t.IsZero() { + continue + } + } + default: + continue + } + args = append(args, val) + colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") + } + + var condition = "" + st := session.CurrentStatement() + if st != nil && st.WhereStr != "" { + condition = fmt.Sprintf("WHERE %v", st.WhereStr) + } + + if condition == "" { + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName) + if fieldValue.Int() != 0 { + condition = fmt.Sprintf("WHERE %v = ?", table.PKColumn().Name) + args = append(args, fieldValue.Interface()) + } + } + + statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v", + session.Engine.QuoteIdentifier, + tableName, + session.Engine.QuoteIdentifier, + strings.Join(colNames, ", "), + condition) + + if session.Engine.ShowSQL { + fmt.Println(statement) + } + + res, err := session.Exec(statement, args...) + if err != nil { + return -1, err + } + + id, err := res.RowsAffected() + + if err != nil { + return -1, err + } + return id, nil +} + +func (session *Session) Delete(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + fieldType := reflect.TypeOf(fieldValue.Interface()) + val := fieldValue.Interface() + switch fieldType.Kind() { + case reflect.String: + if fieldValue.String() == "" { + continue + } + case reflect.Int, reflect.Int32, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if t.IsZero() { + continue + } + } + default: + continue + } + args = append(args, val) + colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") + } + + var condition = "" + st := session.CurrentStatement() + if st != nil && st.WhereStr != "" { + condition = fmt.Sprintf("WHERE %v", st.WhereStr) + if len(colNames) > 0 { + condition += " and " + condition += strings.Join(colNames, " and ") + } + } else { + condition = "WHERE " + strings.Join(colNames, " and ") + } + + statement := fmt.Sprintf("DELETE FROM %v%v%v %v", + session.Engine.QuoteIdentifier, + tableName, + session.Engine.QuoteIdentifier, + condition) + + if session.Engine.ShowSQL { + fmt.Println(statement) + } + + res, err := session.Exec(statement, args...) + if err != nil { + return -1, err + } + + id, err := res.RowsAffected() + + if err != nil { + return -1, err + } + return id, nil +} diff --git a/statement.go b/statement.go new file mode 100644 index 00000000..868ad018 --- /dev/null +++ b/statement.go @@ -0,0 +1,144 @@ +package xorm + +import ( + "fmt" +) + +type Statement struct { + TableName string + Table *Table + Session *Session + LimitStr int + OffsetStr int + WhereStr string + ParamStr []interface{} + OrderStr string + JoinStr string + GroupByStr string + HavingStr string +} + +func (statement *Statement) Limit(start int, size ...int) *Statement { + statement.LimitStr = start + if len(size) > 0 { + statement.OffsetStr = size[0] + } + return statement +} + +func (statement *Statement) Offset(offset int) *Statement { + statement.OffsetStr = offset + return statement +} + +func (statement *Statement) OrderBy(order string) *Statement { + statement.OrderStr = order + return statement +} + +func (statement *Statement) Select(colums string) *Statement { + //statement.ColumnStr = colums + return statement +} + +func (statement Statement) generateSql() string { + columnStr := statement.Table.ColumnStr() + return statement.genSelectSql(columnStr) +} + +func (statement Statement) genCountSql() string { + return statement.genSelectSql("count(*) as total") +} + +func (statement Statement) genSelectSql(columnStr string) (a string) { + session := statement.Session + if session.Engine.Protocol == "mssql" { + if statement.OffsetStr > 0 { + a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", + statement.Table.PKColumn().Name, + columnStr, + statement.Table.Name) + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + a = fmt.Sprintf("select %v from (%v) "+ + "as a where rownum between %v and %v", + columnStr, + a, + statement.OffsetStr, + statement.LimitStr) + } else if statement.LimitStr > 0 { + a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name) + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + } else { + a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + } + } else { + a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) + if statement.JoinStr != "" { + a = fmt.Sprintf("%v %v", a, statement.JoinStr) + } + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + if statement.OffsetStr > 0 { + a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr) + } else if statement.LimitStr > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr) + } + } + return +} + +/*func (statement *Statement) genInsertSQL() string { + table = statement.Table + colNames := make([]string, len(table.Columns)) + for idx, col := range table.Columns { + if col.Name == "" { + continue + } + colNames[idx] = col.Name + } + return strings.Join(colNames, ", ") + + colNames := make([]string, len(table.Columns)) + for idx, col := range table.Columns { + if col.Name == "" { + continue + } + colNames[idx] = "?" + } + strings.Join(colNames, ", ") +}*/ diff --git a/xorm.go b/xorm.go new file mode 100644 index 00000000..2acebd95 --- /dev/null +++ b/xorm.go @@ -0,0 +1,47 @@ +package xorm + +import ( + "strings" +) + +// 'sqlite:///foo.db' +// 'sqlite:////Uses/lunny/foo.db' +// 'sqlite:///:memory:' +// '://:@/?charset=' +func Create(schema string) Engine { + engine := Engine{} + engine.Mapper = SnakeMapper{} + engine.Tables = make(map[string]Table) + l := strings.Split(schema, "://") + if len(l) == 2 { + engine.Protocol = l[0] + if l[0] == "sqlite" { + engine.Charset = "utf8" + engine.AutoIncrement = "AUTOINCREMENT" + if l[1] == "/:memory:" { + engine.Others = l[1] + } else if strings.Index(l[1], "//") == 0 { + engine.Others = l[1][1:] + } else if strings.Index(l[1], "/") == 0 { + engine.Others = "." + l[1] + } + } else { + engine.AutoIncrement = "AUTO_INCREMENT" + x := strings.Split(l[1], ":") + engine.UserName = x[0] + y := strings.Split(x[1], "@") + engine.Password = y[0] + z := strings.Split(y[1], "/") + engine.Host = z[0] + a := strings.Split(z[1], "?") + engine.DBName = a[0] + if len(a) == 2 { + engine.Charset = strings.Split(a[1], "=")[1] + } else { + engine.Charset = "utf8" + } + } + } + + return engine +} diff --git a/xorm_test.go b/xorm_test.go new file mode 100644 index 00000000..f785851c --- /dev/null +++ b/xorm_test.go @@ -0,0 +1,190 @@ +package xorm_test + +import ( + "fmt" + _ "github.com/Go-SQL-Driver/MySQL" + //_ "github.com/ziutek/mymysql/godrv" + //_ "github.com/mattn/go-sqlite3" + "testing" + "time" + "xorm" +) + +/* +CREATE TABLE `userinfo` ( + `uid` INT(10) NULL AUTO_INCREMENT, + `username` VARCHAR(64) NULL, + `departname` VARCHAR(64) NULL, + `created` DATE NULL, + PRIMARY KEY (`uid`) +); +CREATE TABLE `userdeatail` ( + `uid` INT(10) NULL, + `intro` TEXT NULL, + `profile` TEXT NULL, + PRIMARY KEY (`uid`) +); +*/ + +type Userinfo struct { + Uid int `xorm:"id pk not null autoincr"` + Username string + Departname string + Alias string `xorm:"-"` + Created time.Time +} + +var engine xorm.Engine + +func TestCreateEngine(t *testing.T) { + engine = xorm.Create("mysql://root:123@localhost/test") + //engine = orm.Create("mymysql://root:123@localhost/test") + //engine = orm.Create("sqlite:///test.db") + engine.ShowSQL = true +} + +func TestDirectCreateTable(t *testing.T) { + err := engine.CreateTables(&Userinfo{}) + if err != nil { + t.Error(err) + } +} + +func TestMapper(t *testing.T) { + err := engine.UnMap(&Userinfo{}) + if err != nil { + t.Error(err) + } + + err = engine.Map(&Userinfo{}) + if err != nil { + t.Error(err) + } + + err = engine.DropAll() + if err != nil { + t.Error(err) + } + + err = engine.CreateAll() + if err != nil { + t.Error(err) + } +} + +func TestInsert(t *testing.T) { + user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()} + _, err := engine.Insert(&user) + if err != nil { + t.Error(err) + } +} + +func TestInsertAutoIncr(t *testing.T) { + // auto increment insert + user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err := engine.Insert(&user) + if err != nil { + t.Error(err) + } +} + +func TestInsertMulti(t *testing.T) { + user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()} + user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()} + _, err := engine.Insert(&user1, &user2) + if err != nil { + t.Error(err) + } +} + +func TestUpdate(t *testing.T) { + // update by id + user := Userinfo{Uid: 1, Username: "xxx"} + _, err := engine.Update(&user) + if err != nil { + t.Error(err) + } +} + +func TestDelete(t *testing.T) { + user := Userinfo{Uid: 1} + _, err := engine.Delete(&user) + if err != nil { + t.Error(err) + } +} + +func TestGet(t *testing.T) { + user := Userinfo{Uid: 2} + + err := engine.Get(&user) + if err != nil { + t.Error(err) + } + fmt.Println(user) +} + +func TestFind(t *testing.T) { + users := make([]Userinfo, 0) + + err := engine.Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestCount(t *testing.T) { + user := Userinfo{} + total, err := engine.Count(&user) + if err != nil { + t.Error(err) + } + fmt.Printf("Total %d records!!!", total) +} + +func TestWhere(t *testing.T) { + users := make([]Userinfo, 0) + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + } + err = session.Where("id > ?", 2).Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestLimit(t *testing.T) { + users := make([]Userinfo, 0) + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + } + err = session.Limit(2, 1).Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestOrder(t *testing.T) { + users := make([]Userinfo, 0) + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + } + err = session.OrderBy("id desc").Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestTransaction(*testing.T) { +}