commit 4969e8bf94d4df8b9351c3f62f8724234e8d3b3a Author: Lunny Xiao Date: Fri May 3 15:26:51 2013 +0800 init project diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..00268614 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/README.md b/README.md new file mode 100644 index 00000000..256aa34d --- /dev/null +++ b/README.md @@ -0,0 +1,127 @@ +orm +===== + +orm is an ORM for Go. It lets you map Go structs to tables in a database. + +Right now, it interfaces with Mysql/SQLite. The goal however is to add support for PostgreSQL/DB2/MS ADODB/ODBC/Oracle in the future. + +All in all, it's not entirely ready for advanced use yet, but it's getting there. + +Drivers for Go's sql package which support database/sql includes: + +Mysql:[github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) + +Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) + +SQLite:[github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + +### Installing xorm + go get github.com/lunny/xorm + +### Quick Start + +1. Create an database engine (for example: mysql) + +```go +engine := xorm.Create("mysql://root:123@localhost/test") +``` + +2. Define your struct + +```go +type User struct { + Id int + Name string + Age int `xorm:"-"` +} +``` + +for Simple Task, just use engine's functions: + +begin start, you should create a database and then we create the tables + +```go +err := engine.CreateTables(&User{}) +``` + +then, insert an struct to table + +```go +id, err := engine.Insert(&User{Name:"lunny"}) +``` + +or you want to update this struct + +```go +user := User{Id:1, Name:"xlw"} +rows, err := engine.Update(&user) +``` + +3. Fetch a single object by user +```go +var user = User{Id:27} +engine.Get(&user) + +var user = User{Name:"xlw"} +engine.Get(&user) +``` + +for deep use, you should create a session, this func will create a connection to db + +```go +session, err := engine.MakeSession() +defer session.Close() +if err != nil { + return +} +``` + +1. Fetch a single object by where +```go +var user Userinfo +session.Where("id=?", 27).Get(&user) + +var user2 Userinfo +session.Where(3).Get(&user2) // this is shorthand for the version above + +var user3 Userinfo +session.Where("name = ?", "john").Get(&user3) // more complex query + +var user4 Userinfo +session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex +``` + +2. Fetch multiple objects + +```go +var allusers []Userinfo +err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 + +var tenusers []Userinfo +err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0 + +var everyone []Userinfo +err := session.Find(&everyone) +``` + +###***About Map Rules*** +1. Struct and struct's fields name should be Pascal style, and the table and column's name default is us +for example: +The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname. +If the keyname is 'UserName' will turn into the select colum 'user_name' + +2. You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper + +another is use field tag, field tag support the below keywords: +[name] column name +pk the field is a primary key +int(11)/varchar(50) column type +autoincr auto incrment +[not ]null if column can be null value +unique unique +- this field is not map as a table column + +## LICENSE + + BSD License + [http://creativecommons.org/licenses/BSD/](http://creativecommons.org/licenses/BSD/) diff --git a/engine.go b/engine.go new file mode 100644 index 00000000..75575a74 --- /dev/null +++ b/engine.go @@ -0,0 +1,435 @@ +package xorm + +import ( + "database/sql" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +type SQLType struct { + Name string + DefaultLength int +} + +var ( + Int = SQLType{"int", 11} + Char = SQLType{"char", 1} + Varchar = SQLType{"varchar", 50} + Date = SQLType{"date", 24} + Decimal = SQLType{"decimal", 26} + Float = SQLType{"float", 31} + Double = SQLType{"double", 31} +) + +const ( + PQSQL = "pqsql" + MSSQL = "mssql" + SQLITE = "sqlite" + MYSQL = "mysql" + MYMYSQL = "mymysql" +) + +type Column struct { + Name string + FieldName string + SQLType SQLType + Length int + Nullable bool + Default string + IsUnique bool + IsPrimaryKey bool + AutoIncrement bool +} + +type Table struct { + Name string + Type reflect.Type + Columns map[string]Column + PrimaryKey string +} + +func (table *Table) ColumnStr() string { + colNames := make([]string, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + colNames = append(colNames, col.Name) + } + return strings.Join(colNames, ", ") +} + +func (table *Table) PlaceHolders() string { + colNames := make([]string, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + colNames = append(colNames, "?") + } + return strings.Join(colNames, ", ") +} + +func (table *Table) PKColumn() Column { + return table.Columns[table.PrimaryKey] +} + +type Engine struct { + Mapper IMapper + Protocol string + UserName string + Password string + Host string + Port int + DBName string + Charset string + Others string + Tables map[string]Table + AutoIncrement string + ShowSQL bool + QuoteIdentifier string +} + +func (e *Engine) OpenDB() (db *sql.DB, err error) { + db = nil + err = nil + if e.Protocol == "sqlite" { + // 'sqlite:///foo.db' + db, err = sql.Open("sqlite3", e.Others) + // 'sqlite:///:memory:' + } else if e.Protocol == "mysql" { + // 'mysql://:@/?charset=' + connstr := strings.Join([]string{e.UserName, ":", + e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "") + db, err = sql.Open(e.Protocol, connstr) + } else if e.Protocol == "mymysql" { + // DBNAME/USER/PASSWD + connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/") + db, err = sql.Open(e.Protocol, connstr) + // unix:SOCKPATH*DBNAME/USER/PASSWD + // unix:SOCKPATH,OPTIONS*DBNAME/USER/PASSWD + // tcp:ADDR*DBNAME/USER/PASSWD + // tcp:ADDR,OPTIONS*DBNAME/USER/PASSWD + } + + return +} + +func (engine *Engine) MakeSession() (session Session, err error) { + db, err := engine.OpenDB() + if err != nil { + return Session{}, err + } + if engine.Protocol == "pgsql" { + engine.QuoteIdentifier = "\"" + session = Session{Engine: engine, Db: db, ParamIteration: 1} + } else if engine.Protocol == "mssql" { + engine.QuoteIdentifier = "" + session = Session{Engine: engine, Db: db, ParamIteration: 1} + } else { + engine.QuoteIdentifier = "`" + session = Session{Engine: engine, Db: db, ParamIteration: 1} + } + session.Mapper = engine.Mapper + session.Init() + return +} + +func (sqlType SQLType) genSQL(length int) string { + if sqlType == Date { + return " datetime " + } + return sqlType.Name + "(" + strconv.Itoa(length) + ")" +} + +func (e *Engine) genCreateSQL(table *Table) string { + sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` (" + //fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey)) + for _, col := range table.Columns { + if col.Name != "" { + sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " " + if col.Nullable { + sql += " NULL " + } else { + sql += " NOT NULL " + } + //fmt.Println(key) + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + } + if col.AutoIncrement { + sql += e.AutoIncrement + " " + } + if col.IsUnique { + sql += "Unique " + } + sql += "," + } + } + sql = sql[:len(sql)-2] + ");" + if e.ShowSQL { + fmt.Println(sql) + } + return sql +} + +func (e *Engine) genDropSQL(table *Table) string { + sql := "DROP TABLE IF EXISTS `" + table.Name + "`;" + if e.ShowSQL { + fmt.Println(sql) + } + return sql +} + +func Type(bean interface{}) reflect.Type { + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + return reflect.TypeOf(sliceValue.Interface()) +} + +func Type2SQLType(t reflect.Type) (st SQLType) { + switch k := t.Kind(); k { + case reflect.Int, reflect.Int32, reflect.Int64: + st = Int + case reflect.String: + st = Varchar + case reflect.Struct: + if t == reflect.TypeOf(time.Time{}) { + st = Date + } + default: + st = Varchar + } + return +} + +/* +map an object into a table object +*/ +func (engine *Engine) MapOne(bean interface{}) Table { + t := Type(bean) + return engine.MapType(t) +} + +func (engine *Engine) MapType(t reflect.Type) Table { + table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} + table.Columns = make(map[string]Column) + var pkCol *Column = nil + var pkstr = "" + + for i := 0; i < t.NumField(); i++ { + tag := t.Field(i).Tag + ormTagStr := tag.Get("xorm") + var col Column + fieldType := t.Field(i).Type + + if ormTagStr != "" { + col = Column{FieldName: t.Field(i).Name} + ormTagStr = strings.ToLower(ormTagStr) + tags := strings.Split(ormTagStr, " ") + // TODO: + if len(tags) > 0 { + if tags[0] == "-" { + continue + } + for j, key := range tags { + switch k := strings.ToLower(key); k { + case "pk": + col.IsPrimaryKey = true + pkCol = &col + case "null": + col.Nullable = (tags[j-1] != "not") + case "autoincr": + col.AutoIncrement = true + case "default": + col.Default = tags[j+1] + case "int": + col.SQLType = Int + case "not": + default: + col.Name = k + } + } + if col.SQLType.Name == "" { + col.SQLType = Type2SQLType(fieldType) + col.Length = col.SQLType.DefaultLength + } + + if col.Name == "" { + col.Name = engine.Mapper.Obj2Table(t.Field(i).Name) + } + } + } + + if col.Name == "" { + sqlType := Type2SQLType(fieldType) + col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, + sqlType.DefaultLength, true, "", false, false, false} + } + table.Columns[col.Name] = col + if strings.ToLower(t.Field(i).Name) == "id" { + pkstr = col.Name + } + } + + if pkCol == nil { + if pkstr != "" { + col := table.Columns[pkstr] + col.IsPrimaryKey = true + col.AutoIncrement = true + col.Nullable = false + col.Length = Int.DefaultLength + table.PrimaryKey = col.Name + } + } else { + table.PrimaryKey = pkCol.Name + } + + return table +} + +func (engine *Engine) Map(beans ...interface{}) (e error) { + for _, bean := range beans { + //t := getBeanType(bean) + tableName := engine.Mapper.Obj2Table(StructName(bean)) + if _, ok := engine.Tables[tableName]; !ok { + table := engine.MapOne(bean) + engine.Tables[table.Name] = table + } + } + return +} + +func (engine *Engine) UnMap(beans ...interface{}) (e error) { + for _, bean := range beans { + //t := getBeanType(bean) + tableName := engine.Mapper.Obj2Table(StructName(bean)) + if _, ok := engine.Tables[tableName]; ok { + delete(engine.Tables, tableName) + } + } + return +} + +func (e *Engine) DropAll() error { + session, err := e.MakeSession() + session.Begin() + defer session.Close() + if err != nil { + return err + } + + for _, table := range e.Tables { + sql := e.genDropSQL(&table) + _, err = session.Exec(sql) + if err != nil { + session.Rollback() + break + } + } + session.Commit() + return err +} + +func (e *Engine) CreateTables(beans ...interface{}) error { + session, err := e.MakeSession() + session.Begin() + defer session.Close() + if err != nil { + return err + } + for _, bean := range beans { + table := e.MapOne(bean) + e.Tables[table.Name] = table + sql := e.genCreateSQL(&table) + _, err = session.Exec(sql) + if err != nil { + session.Rollback() + break + } + } + session.Commit() + return err +} + +func (e *Engine) CreateAll() error { + session, err := e.MakeSession() + session.Begin() + defer session.Close() + if err != nil { + return err + } + + for _, table := range e.Tables { + sql := e.genCreateSQL(&table) + _, err = session.Exec(sql) + if err != nil { + session.Rollback() + break + } + } + session.Commit() + return err +} + +func (engine *Engine) Insert(beans ...interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return -1, err + } + + return session.Insert(beans...) +} + +func (engine *Engine) Update(bean interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return -1, err + } + + return session.Update(bean) +} + +func (engine *Engine) Delete(bean interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return -1, err + } + + return session.Delete(bean) +} + +func (engine *Engine) Get(bean interface{}) error { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return err + } + + return session.Get(bean) +} + +func (engine *Engine) Find(beans interface{}) error { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return err + } + + return session.Find(beans) +} + +func (engine *Engine) Count(bean interface{}) (int64, error) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return 0, err + } + + return session.Count(bean) +} diff --git a/install b/install new file mode 100755 index 00000000..1ccc8fc6 --- /dev/null +++ b/install @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +if [ ! -f install ]; then +echo 'install must be run within its container folder' 1>&2 +exit 1 +fi + +CURDIR=`pwd` +NEWPATH="$GOPATH/src/${PWD##*/}" +if [ ! -d "$NEWPATH" ]; then +ln -s $CURDIR $NEWPATH +fi + +gofmt -w $CURDIR + +cd $NEWPATH +go install ${PWD##*/} +cd $CURDIR + +echo 'finished' diff --git a/mapper.go b/mapper.go new file mode 100644 index 00000000..a00682b5 --- /dev/null +++ b/mapper.go @@ -0,0 +1,80 @@ +package xorm + +import ( +//"reflect" +//"strings" +) + +type IMapper interface { + Obj2Table(string) string + Table2Obj(string) string +} + +type SnakeMapper struct { +} + +func snakeCasedName(name string) string { + newstr := make([]rune, 0) + firstTime := true + + for _, chr := range name { + if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { + if firstTime == true { + firstTime = false + } else { + newstr = append(newstr, '_') + } + chr -= ('A' - 'a') + } + newstr = append(newstr, chr) + } + + return string(newstr) +} + +func Pascal2Sql(s string) (d string) { + d = "" + lastIdx := 0 + for i := 0; i < len(s); i++ { + if s[i] >= 'A' && s[i] <= 'Z' { + if lastIdx < i { + d += s[lastIdx+1 : i] + } + if i != 0 { + d += "_" + } + d += string(s[i] + 32) + lastIdx = i + } + } + d += s[lastIdx+1:] + return +} + +func (mapper SnakeMapper) Obj2Table(name string) string { + return snakeCasedName(name) +} + +func titleCasedName(name string) string { + newstr := make([]rune, 0) + upNextChar := true + + for _, chr := range name { + switch { + case upNextChar: + upNextChar = false + chr -= ('a' - 'A') + case chr == '_': + upNextChar = true + continue + } + + newstr = append(newstr, chr) + } + + return string(newstr) +} + +func (mapper SnakeMapper) Table2Obj(name string) string { + return titleCasedName(name) +} diff --git a/session.go b/session.go new file mode 100644 index 00000000..784f2d86 --- /dev/null +++ b/session.go @@ -0,0 +1,601 @@ +package xorm + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +func getTypeName(obj interface{}) (typestr string) { + typ := reflect.TypeOf(obj) + typestr = typ.String() + + lastDotIndex := strings.LastIndex(typestr, ".") + if lastDotIndex != -1 { + typestr = typestr[lastDotIndex+1:] + } + + return +} + +func StructName(s interface{}) string { + v := reflect.TypeOf(s) + return Type2StructName(v) +} + +func Type2StructName(v reflect.Type) string { + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Name() +} + +type Session struct { + Db *sql.DB + Engine *Engine + Statements []Statement + Mapper IMapper + AutoCommit bool + ParamIteration int + CurStatementIdx int +} + +func (session *Session) Init() { + session.Statements = make([]Statement, 0) + /*session.Statement.TableName = "" + session.Statement.LimitStr = 0 + session.Statement.OffsetStr = 0 + session.Statement.WhereStr = "" + session.Statement.ParamStr = make([]interface{}, 0) + session.Statement.OrderStr = "" + session.Statement.JoinStr = "" + session.Statement.GroupByStr = "" + session.Statement.HavingStr = ""*/ + session.CurStatementIdx = -1 + + session.ParamIteration = 1 +} + +func (session *Session) Close() { + defer session.Db.Close() +} + +func (session *Session) CurrentStatement() *Statement { + if session.CurStatementIdx > -1 { + return &session.Statements[session.CurStatementIdx] + } + return nil +} + +func (session *Session) AutoStatement() *Statement { + if session.CurStatementIdx == -1 { + session.newStatement() + } + return session.CurrentStatement() +} + +//Execute sql +func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(finalQueryString) + if err != nil { + return nil, err + } + defer rs.Close() + + res, err := rs.Exec(args...) + if err != nil { + return nil, err + } + return res, nil +} + +func (session *Session) Where(querystring interface{}, args ...interface{}) *Session { + statement := session.AutoStatement() + switch querystring := querystring.(type) { + case string: + statement.WhereStr = querystring + case int: + if session.Engine.Protocol == "pgsql" { + statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration) + } else { + statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier) + session.ParamIteration++ + } + args = append(args, querystring) + } + statement.ParamStr = args + return session +} + +func (session *Session) Limit(start int, size ...int) *Session { + session.AutoStatement().LimitStr = start + if len(size) > 0 { + session.CurrentStatement().OffsetStr = size[0] + } + return session +} + +func (session *Session) Offset(offset int) *Session { + session.AutoStatement().OffsetStr = offset + return session +} + +func (session *Session) OrderBy(order string) *Session { + session.AutoStatement().OrderStr = order + return session +} + +//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (session *Session) Join(join_operator, tablename, condition string) *Session { + if session.AutoStatement().JoinStr != "" { + session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } else { + session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } + + return session +} + +func (session *Session) GroupBy(keys string) *Session { + session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys) + return session +} + +func (session *Session) Having(conditions string) *Session { + session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions) + return session +} + +func (session *Session) Begin() { + +} + +func (session *Session) Rollback() { + +} + +func (session *Session) Commit() { + for _, statement := range session.Statements { + sql := statement.generateSql() + session.Exec(sql) + } +} + +func (session *Session) TableName(bean interface{}) string { + return session.Mapper.Obj2Table(StructName(bean)) +} + +func (session *Session) newStatement() { + state := Statement{} + state.Session = session + session.Statements = append(session.Statements, state) + session.CurStatementIdx = len(session.Statements) - 1 +} + +func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { + dataStruct := reflect.Indirect(reflect.ValueOf(obj)) + if dataStruct.Kind() != reflect.Struct { + return errors.New("expected a pointer to a struct") + } + + tablName := session.TableName(obj) + table := session.Engine.Tables[tablName] + + for key, data := range objMap { + structField := dataStruct.FieldByName(table.Columns[key].FieldName) + if !structField.CanSet() { + continue + } + + var v interface{} + + switch structField.Type().Kind() { + case reflect.Slice: + v = data + case reflect.String: + v = string(data) + case reflect.Bool: + v = string(data) == "1" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: + x, err := strconv.Atoi(string(data)) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + v = x + case reflect.Int64: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + v = x + case reflect.Float32, reflect.Float64: + x, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return errors.New("arg " + key + " as float64: " + err.Error()) + } + v = x + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + v = x + //Now only support Time type + case reflect.Struct: + if structField.Type().String() != "time.Time" { + return errors.New("unsupported struct type in Scan: " + structField.Type().String()) + } + + x, err := time.Parse("2006-01-02 15:04:05", string(data)) + if err != nil { + x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data)) + + if err != nil { + return errors.New("unsupported time format: " + string(data)) + } + } + + v = x + default: + return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) + } + + structField.Set(reflect.ValueOf(v)) + } + + return nil +} + +func (session *Session) Get(output interface{}) error { + statement := session.AutoStatement() + session.Limit(1) + tableName := session.TableName(output) + table := session.Engine.Tables[tableName] + statement.Table = &table + + resultsSlice, err := session.FindMap(statement) + if err != nil { + return err + } + if len(resultsSlice) == 0 { + return nil + } else if len(resultsSlice) == 1 { + results := resultsSlice[0] + err := session.scanMapIntoStruct(output, results) + if err != nil { + return err + } + } else { + return errors.New("More than one record") + } + return nil +} + +func (session *Session) Count(bean interface{}) (int64, error) { + statement := session.AutoStatement() + session.Limit(1) + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + statement.Table = &table + + resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr) + if err != nil { + return 0, err + } + + var total int64 = 0 + for _, results := range resultsSlice { + total, err = strconv.ParseInt(string(results["total"]), 10, 64) + break + } + + return int64(total), err +} + +func (session *Session) Find(rowsSlicePtr interface{}) error { + statement := session.AutoStatement() + + sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) + if sliceValue.Kind() != reflect.Slice { + return errors.New("needs a pointer to a slice") + } + + sliceElementType := sliceValue.Type().Elem() + + tableName := session.Mapper.Obj2Table(Type2StructName(sliceElementType)) + table := session.Engine.Tables[tableName] + statement.Table = &table + + resultsSlice, err := session.FindMap(statement) + if err != nil { + return err + } + + for _, results := range resultsSlice { + newValue := reflect.New(sliceElementType) + err := session.scanMapIntoStruct(newValue.Interface(), results) + if err != nil { + return err + } + sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) + } + return nil +} + +func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSlice []map[string][]byte, err error) { + if session.Engine.ShowSQL { + fmt.Println(sqls) + } + s, err := session.Db.Prepare(sqls) + if err != nil { + return nil, err + } + defer s.Close() + res, err := s.Query(paramStr...) + if err != nil { + return nil, err + } + defer res.Close() + fields, err := res.Columns() + if err != nil { + return nil, err + } + for res.Next() { + result := make(map[string][]byte) + var scanResultContainers []interface{} + for i := 0; i < len(fields); i++ { + var scanResultContainer interface{} + scanResultContainers = append(scanResultContainers, &scanResultContainer) + } + if err := res.Scan(scanResultContainers...); err != nil { + return nil, err + } + for ii, key := range fields { + rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + + //if row is null then ignore + if rawValue.Interface() == nil { + continue + } + aa := reflect.TypeOf(rawValue.Interface()) + vv := reflect.ValueOf(rawValue.Interface()) + var str string + switch aa.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(vv.Int(), 10) + + result[key] = []byte(str) + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str = strconv.FormatUint(vv.Uint(), 10) + result[key] = []byte(str) + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(vv.Float(), 'f', -1, 64) + result[key] = []byte(str) + case reflect.Slice: + if aa.Elem().Kind() == reflect.Uint8 { + result[key] = rawValue.Interface().([]byte) + break + } + case reflect.String: + str = vv.String() + result[key] = []byte(str) + //时间类型 + case reflect.Struct: + str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") + result[key] = []byte(str) + } + + } + resultsSlice = append(resultsSlice, result) + } + return resultsSlice, nil +} + +func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) { + sqls := statement.generateSql() + return session.SQL2Map(sqls, statement.ParamStr) +} + +func (session *Session) Insert(beans ...interface{}) (int64, error) { + var lastId int64 = -1 + for _, bean := range beans { + lastId, err := session.InsertOne(bean) + if err != nil { + return lastId, err + } + } + return lastId, nil +} + +func (session *Session) InsertOne(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + + colNames := make([]string, 0) + colPlaces := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + val := fieldValue.Interface() + if col.AutoIncrement { + if fieldValue.Int() == 0 { + continue + } + } + args = append(args, val) + colNames = append(colNames, col.Name) + colPlaces = append(colPlaces, "?") + } + + statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", + session.Engine.QuoteIdentifier, + tableName, + session.Engine.QuoteIdentifier, + strings.Join(colNames, ", "), + strings.Join(colPlaces, ", ")) + + if session.Engine.ShowSQL { + fmt.Println(statement) + } + + res, err := session.Exec(statement, args...) + if err != nil { + return -1, err + } + + id, err := res.LastInsertId() + + if err != nil { + return -1, err + } + return id, nil +} + +func (session *Session) Update(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + fieldType := reflect.TypeOf(fieldValue.Interface()) + val := fieldValue.Interface() + switch fieldType.Kind() { + case reflect.String: + if fieldValue.String() == "" { + continue + } + case reflect.Int, reflect.Int32, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if t.IsZero() { + continue + } + } + default: + continue + } + args = append(args, val) + colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") + } + + var condition = "" + st := session.CurrentStatement() + if st != nil && st.WhereStr != "" { + condition = fmt.Sprintf("WHERE %v", st.WhereStr) + } + + if condition == "" { + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(table.PKColumn().FieldName) + if fieldValue.Int() != 0 { + condition = fmt.Sprintf("WHERE %v = ?", table.PKColumn().Name) + args = append(args, fieldValue.Interface()) + } + } + + statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v", + session.Engine.QuoteIdentifier, + tableName, + session.Engine.QuoteIdentifier, + strings.Join(colNames, ", "), + condition) + + if session.Engine.ShowSQL { + fmt.Println(statement) + } + + res, err := session.Exec(statement, args...) + if err != nil { + return -1, err + } + + id, err := res.RowsAffected() + + if err != nil { + return -1, err + } + return id, nil +} + +func (session *Session) Delete(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + if col.Name == "" { + continue + } + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + fieldType := reflect.TypeOf(fieldValue.Interface()) + val := fieldValue.Interface() + switch fieldType.Kind() { + case reflect.String: + if fieldValue.String() == "" { + continue + } + case reflect.Int, reflect.Int32, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if t.IsZero() { + continue + } + } + default: + continue + } + args = append(args, val) + colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") + } + + var condition = "" + st := session.CurrentStatement() + if st != nil && st.WhereStr != "" { + condition = fmt.Sprintf("WHERE %v", st.WhereStr) + if len(colNames) > 0 { + condition += " and " + condition += strings.Join(colNames, " and ") + } + } else { + condition = "WHERE " + strings.Join(colNames, " and ") + } + + statement := fmt.Sprintf("DELETE FROM %v%v%v %v", + session.Engine.QuoteIdentifier, + tableName, + session.Engine.QuoteIdentifier, + condition) + + if session.Engine.ShowSQL { + fmt.Println(statement) + } + + res, err := session.Exec(statement, args...) + if err != nil { + return -1, err + } + + id, err := res.RowsAffected() + + if err != nil { + return -1, err + } + return id, nil +} diff --git a/statement.go b/statement.go new file mode 100644 index 00000000..868ad018 --- /dev/null +++ b/statement.go @@ -0,0 +1,144 @@ +package xorm + +import ( + "fmt" +) + +type Statement struct { + TableName string + Table *Table + Session *Session + LimitStr int + OffsetStr int + WhereStr string + ParamStr []interface{} + OrderStr string + JoinStr string + GroupByStr string + HavingStr string +} + +func (statement *Statement) Limit(start int, size ...int) *Statement { + statement.LimitStr = start + if len(size) > 0 { + statement.OffsetStr = size[0] + } + return statement +} + +func (statement *Statement) Offset(offset int) *Statement { + statement.OffsetStr = offset + return statement +} + +func (statement *Statement) OrderBy(order string) *Statement { + statement.OrderStr = order + return statement +} + +func (statement *Statement) Select(colums string) *Statement { + //statement.ColumnStr = colums + return statement +} + +func (statement Statement) generateSql() string { + columnStr := statement.Table.ColumnStr() + return statement.genSelectSql(columnStr) +} + +func (statement Statement) genCountSql() string { + return statement.genSelectSql("count(*) as total") +} + +func (statement Statement) genSelectSql(columnStr string) (a string) { + session := statement.Session + if session.Engine.Protocol == "mssql" { + if statement.OffsetStr > 0 { + a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", + statement.Table.PKColumn().Name, + columnStr, + statement.Table.Name) + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + a = fmt.Sprintf("select %v from (%v) "+ + "as a where rownum between %v and %v", + columnStr, + a, + statement.OffsetStr, + statement.LimitStr) + } else if statement.LimitStr > 0 { + a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name) + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + } else { + a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + } + } else { + a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) + if statement.JoinStr != "" { + a = fmt.Sprintf("%v %v", a, statement.JoinStr) + } + if statement.WhereStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + } + if statement.GroupByStr != "" { + a = fmt.Sprintf("%v %v", a, statement.GroupByStr) + } + if statement.HavingStr != "" { + a = fmt.Sprintf("%v %v", a, statement.HavingStr) + } + if statement.OrderStr != "" { + a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + } + if statement.OffsetStr > 0 { + a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr) + } else if statement.LimitStr > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr) + } + } + return +} + +/*func (statement *Statement) genInsertSQL() string { + table = statement.Table + colNames := make([]string, len(table.Columns)) + for idx, col := range table.Columns { + if col.Name == "" { + continue + } + colNames[idx] = col.Name + } + return strings.Join(colNames, ", ") + + colNames := make([]string, len(table.Columns)) + for idx, col := range table.Columns { + if col.Name == "" { + continue + } + colNames[idx] = "?" + } + strings.Join(colNames, ", ") +}*/ diff --git a/xorm.go b/xorm.go new file mode 100644 index 00000000..2acebd95 --- /dev/null +++ b/xorm.go @@ -0,0 +1,47 @@ +package xorm + +import ( + "strings" +) + +// 'sqlite:///foo.db' +// 'sqlite:////Uses/lunny/foo.db' +// 'sqlite:///:memory:' +// '://:@/?charset=' +func Create(schema string) Engine { + engine := Engine{} + engine.Mapper = SnakeMapper{} + engine.Tables = make(map[string]Table) + l := strings.Split(schema, "://") + if len(l) == 2 { + engine.Protocol = l[0] + if l[0] == "sqlite" { + engine.Charset = "utf8" + engine.AutoIncrement = "AUTOINCREMENT" + if l[1] == "/:memory:" { + engine.Others = l[1] + } else if strings.Index(l[1], "//") == 0 { + engine.Others = l[1][1:] + } else if strings.Index(l[1], "/") == 0 { + engine.Others = "." + l[1] + } + } else { + engine.AutoIncrement = "AUTO_INCREMENT" + x := strings.Split(l[1], ":") + engine.UserName = x[0] + y := strings.Split(x[1], "@") + engine.Password = y[0] + z := strings.Split(y[1], "/") + engine.Host = z[0] + a := strings.Split(z[1], "?") + engine.DBName = a[0] + if len(a) == 2 { + engine.Charset = strings.Split(a[1], "=")[1] + } else { + engine.Charset = "utf8" + } + } + } + + return engine +} diff --git a/xorm_test.go b/xorm_test.go new file mode 100644 index 00000000..f785851c --- /dev/null +++ b/xorm_test.go @@ -0,0 +1,190 @@ +package xorm_test + +import ( + "fmt" + _ "github.com/Go-SQL-Driver/MySQL" + //_ "github.com/ziutek/mymysql/godrv" + //_ "github.com/mattn/go-sqlite3" + "testing" + "time" + "xorm" +) + +/* +CREATE TABLE `userinfo` ( + `uid` INT(10) NULL AUTO_INCREMENT, + `username` VARCHAR(64) NULL, + `departname` VARCHAR(64) NULL, + `created` DATE NULL, + PRIMARY KEY (`uid`) +); +CREATE TABLE `userdeatail` ( + `uid` INT(10) NULL, + `intro` TEXT NULL, + `profile` TEXT NULL, + PRIMARY KEY (`uid`) +); +*/ + +type Userinfo struct { + Uid int `xorm:"id pk not null autoincr"` + Username string + Departname string + Alias string `xorm:"-"` + Created time.Time +} + +var engine xorm.Engine + +func TestCreateEngine(t *testing.T) { + engine = xorm.Create("mysql://root:123@localhost/test") + //engine = orm.Create("mymysql://root:123@localhost/test") + //engine = orm.Create("sqlite:///test.db") + engine.ShowSQL = true +} + +func TestDirectCreateTable(t *testing.T) { + err := engine.CreateTables(&Userinfo{}) + if err != nil { + t.Error(err) + } +} + +func TestMapper(t *testing.T) { + err := engine.UnMap(&Userinfo{}) + if err != nil { + t.Error(err) + } + + err = engine.Map(&Userinfo{}) + if err != nil { + t.Error(err) + } + + err = engine.DropAll() + if err != nil { + t.Error(err) + } + + err = engine.CreateAll() + if err != nil { + t.Error(err) + } +} + +func TestInsert(t *testing.T) { + user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()} + _, err := engine.Insert(&user) + if err != nil { + t.Error(err) + } +} + +func TestInsertAutoIncr(t *testing.T) { + // auto increment insert + user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err := engine.Insert(&user) + if err != nil { + t.Error(err) + } +} + +func TestInsertMulti(t *testing.T) { + user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()} + user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()} + _, err := engine.Insert(&user1, &user2) + if err != nil { + t.Error(err) + } +} + +func TestUpdate(t *testing.T) { + // update by id + user := Userinfo{Uid: 1, Username: "xxx"} + _, err := engine.Update(&user) + if err != nil { + t.Error(err) + } +} + +func TestDelete(t *testing.T) { + user := Userinfo{Uid: 1} + _, err := engine.Delete(&user) + if err != nil { + t.Error(err) + } +} + +func TestGet(t *testing.T) { + user := Userinfo{Uid: 2} + + err := engine.Get(&user) + if err != nil { + t.Error(err) + } + fmt.Println(user) +} + +func TestFind(t *testing.T) { + users := make([]Userinfo, 0) + + err := engine.Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestCount(t *testing.T) { + user := Userinfo{} + total, err := engine.Count(&user) + if err != nil { + t.Error(err) + } + fmt.Printf("Total %d records!!!", total) +} + +func TestWhere(t *testing.T) { + users := make([]Userinfo, 0) + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + } + err = session.Where("id > ?", 2).Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestLimit(t *testing.T) { + users := make([]Userinfo, 0) + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + } + err = session.Limit(2, 1).Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestOrder(t *testing.T) { + users := make([]Userinfo, 0) + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + } + err = session.OrderBy("id desc").Find(&users) + if err != nil { + t.Error(err) + } + fmt.Println(users) +} + +func TestTransaction(*testing.T) { +}