diff --git a/.gitignore b/.gitignore index 00268614..d44e3d8f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.o *.a *.so +*.db # Folders _obj diff --git a/README.md b/README.md index 08fd51dd..5c245e5b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -xorm -===== +# xorm +=========== [中文](README_CN.md) @@ -11,104 +11,105 @@ All in all, it's not entirely ready for advanced use yet, but it's getting there Drivers for Go's sql package which support database/sql includes: -Mysql:[github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) +Mysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) -Mysql:[github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) +Mysql: [github.com/Go-SQL-Driver/MySQL](https://github.com/Go-SQL-Driver/MySQL) -SQLite:[github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) +SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) -### Installing xorm - go get github.com/lunny/xorm +## Installing xorm + + go get github.com/lunny/xorm -### Quick Start +## Quick Start -1.Create an database engine (for example: mysql) +1.Create a database engine (for example: mysql) + + engine := xorm.Create("mysql://root:123@localhost/test") -```go -engine := xorm.Create("mysql://root:123@localhost/test") -``` 2.Define your struct -```go -type User struct { - Id int - Name string - Age int `xorm:"-"` -} -``` + + type User struct { + Id int + Name string + Age int `xorm:"-"` + } + for Simple Task, just use engine's functions: -begin start, you should create a database and then we create the tables +before beginning, you should create a database in mysql and then we will create the tables. + + + err := engine.CreateTables(&User{}) -```go -err := engine.CreateTables(&User{}) -``` then, insert an struct to table -```go -id, err := engine.Insert(&User{Name:"lunny"}) -``` + + id, err := engine.Insert(&User{Name:"lunny"}) + or you want to update this struct -```go -user := User{Id:1, Name:"xlw"} -rows, err := engine.Update(&user) -``` + + user := User{Id:1, Name:"xlw"} + rows, err := engine.Update(&user) + 3.Fetch a single object by user -```go -var user = User{Id:27} -engine.Get(&user) -var user = User{Name:"xlw"} -engine.Get(&user) -``` + var user = User{Id:27} + engine.Get(&user) + var user = User{Name:"xlw"} + engine.Get(&user) + + +##Deep Use for deep use, you should create a session, this func will create a connection to db -```go -session, err := engine.MakeSession() -defer session.Close() -if err != nil { - return -} -``` + + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + return + } + 1.Fetch a single object by where -```go -var user Userinfo -session.Where("id=?", 27).Get(&user) -var user2 Userinfo -session.Where(3).Get(&user2) // this is shorthand for the version above + var user Userinfo + session.Where("id=?", 27).Get(&user) -var user3 Userinfo -session.Where("name = ?", "john").Get(&user3) // more complex query + var user2 Userinfo + session.Where(3).Get(&user2) // this is shorthand for the version above + + var user3 Userinfo + session.Where("name = ?", "john").Get(&user3) // more complex query + + var user4 Userinfo + session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex -var user4 Userinfo -session.Where("name = ? and age < ?", "john", 88).Get(&user4) // even more complex -``` 2.Fetch multiple objects -```go -var allusers []Userinfo -err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 -var tenusers []Userinfo -err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0 + var allusers []Userinfo + err := session.Where("id > ?", "3").Limit(10,20).Find(&allusers) //Get id>3 limit 10 offset 20 -var everyone []Userinfo -err := session.Find(&everyone) -``` + var tenusers []Userinfo + err := session.Where("id > ?", "3").Limit(10).Find(&tenusers) //Get id>3 limit 10 if omit offset the default is 0 -###***About Map Rules*** + var everyone []Userinfo + err := session.Find(&everyone) + + +##***Mapping Rules*** 1.Struct and struct's fields name should be Pascal style, and the table and column's name default is us for example: The structs Name 'UserInfo' will turn into the table name 'user_info', the same as the keyname. @@ -117,13 +118,22 @@ If the keyname is 'UserName' will turn into the select colum 'user_name' 2.You have two method to change the rule. One is implement your own Map interface according IMapper, you can find the interface in mapper.go and set it to engine.Mapper another is use field tag, field tag support the below keywords: -[name] column name -pk the field is a primary key -int(11)/varchar(50) column type -autoincr auto incrment -[not ]null if column can be null value -unique unique -- this field is not map as a table column +* [name] column name +* pk the field is a primary key +* int(11)/varchar(50) column type +* autoincr auto incrment +* [not ]null if column can be null value +* unique unique +* \- this field is not map as a table column + +##FAQ +1.How the xorm tag use both with json? + + use space + + type User struct { + User string `json:"user" orm:"user_id"` + } ## LICENSE diff --git a/engine.go b/engine.go index 75575a74..c2665fb3 100644 --- a/engine.go +++ b/engine.go @@ -6,22 +6,6 @@ import ( "reflect" "strconv" "strings" - "time" -) - -type SQLType struct { - Name string - DefaultLength int -} - -var ( - Int = SQLType{"int", 11} - Char = SQLType{"char", 1} - Varchar = SQLType{"varchar", 50} - Date = SQLType{"date", 24} - Decimal = SQLType{"decimal", 26} - Float = SQLType{"float", 31} - Double = SQLType{"double", 31} ) const ( @@ -32,51 +16,6 @@ const ( MYMYSQL = "mymysql" ) -type Column struct { - Name string - FieldName string - SQLType SQLType - Length int - Nullable bool - Default string - IsUnique bool - IsPrimaryKey bool - AutoIncrement bool -} - -type Table struct { - Name string - Type reflect.Type - Columns map[string]Column - PrimaryKey string -} - -func (table *Table) ColumnStr() string { - colNames := make([]string, 0) - for _, col := range table.Columns { - if col.Name == "" { - continue - } - colNames = append(colNames, col.Name) - } - return strings.Join(colNames, ", ") -} - -func (table *Table) PlaceHolders() string { - colNames := make([]string, 0) - for _, col := range table.Columns { - if col.Name == "" { - continue - } - colNames = append(colNames, "?") - } - return strings.Join(colNames, ", ") -} - -func (table *Table) PKColumn() Column { - return table.Columns[table.PrimaryKey] -} - type Engine struct { Mapper IMapper Protocol string @@ -91,21 +30,27 @@ type Engine struct { AutoIncrement string ShowSQL bool QuoteIdentifier string + Statement Statement +} + +func Type(bean interface{}) reflect.Type { + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + return reflect.TypeOf(sliceValue.Interface()) } func (e *Engine) OpenDB() (db *sql.DB, err error) { db = nil err = nil - if e.Protocol == "sqlite" { + if e.Protocol == SQLITE { // 'sqlite:///foo.db' db, err = sql.Open("sqlite3", e.Others) // 'sqlite:///:memory:' - } else if e.Protocol == "mysql" { + } else if e.Protocol == MYSQL { // 'mysql://:@/?charset=' connstr := strings.Join([]string{e.UserName, ":", e.Password, "@tcp(", e.Host, ":3306)/", e.DBName, "?charset=", e.Charset}, "") db, err = sql.Open(e.Protocol, connstr) - } else if e.Protocol == "mymysql" { + } else if e.Protocol == MYMYSQL { // DBNAME/USER/PASSWD connstr := strings.Join([]string{e.DBName, e.UserName, e.Password}, "/") db, err = sql.Open(e.Protocol, connstr) @@ -123,51 +68,95 @@ func (engine *Engine) MakeSession() (session Session, err error) { if err != nil { return Session{}, err } - if engine.Protocol == "pgsql" { + if engine.Protocol == PQSQL { engine.QuoteIdentifier = "\"" - session = Session{Engine: engine, Db: db, ParamIteration: 1} - } else if engine.Protocol == "mssql" { + session = Session{Engine: engine, Db: db} + } else if engine.Protocol == MSSQL { engine.QuoteIdentifier = "" - session = Session{Engine: engine, Db: db, ParamIteration: 1} + session = Session{Engine: engine, Db: db} } else { engine.QuoteIdentifier = "`" - session = Session{Engine: engine, Db: db, ParamIteration: 1} + session = Session{Engine: engine, Db: db} } session.Mapper = engine.Mapper session.Init() return } -func (sqlType SQLType) genSQL(length int) string { - if sqlType == Date { - return " datetime " +func (engine *Engine) Where(querystring string, args ...interface{}) *Engine { + engine.Statement.Where(querystring, args...) + return engine +} + +func (engine *Engine) Limit(limit int, start ...int) *Engine { + engine.Statement.Limit(limit, start...) + return engine +} + +func (engine *Engine) OrderBy(order string) *Engine { + engine.Statement.OrderBy(order) + return engine +} + +//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (engine *Engine) Join(join_operator, tablename, condition string) *Engine { + engine.Statement.Join(join_operator, tablename, condition) + return engine +} + +func (engine *Engine) GroupBy(keys string) *Engine { + engine.Statement.GroupBy(keys) + return engine +} + +func (engine *Engine) Having(conditions string) *Engine { + engine.Statement.Having(conditions) + return engine +} + +func (e *Engine) genColumnStr(col *Column) string { + sql := "`" + col.Name + "` " + if col.SQLType == Date { + sql += " datetime " + } else { + if e.Protocol == SQLITE && col.IsPrimaryKey { + sql += "integer" + } else { + sql += col.SQLType.Name + } + if e.Protocol != SQLITE { + if col.SQLType != Decimal { + sql += "(" + strconv.Itoa(col.Length) + ")" + } else { + sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")" + } + } } - return sqlType.Name + "(" + strconv.Itoa(length) + ")" + + if col.Nullable { + sql += " NULL " + } else { + sql += " NOT NULL " + } + //fmt.Println(key) + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + } + if col.IsAutoIncrement { + sql += e.AutoIncrement + " " + } + if col.IsUnique { + sql += "Unique " + } + return sql } func (e *Engine) genCreateSQL(table *Table) string { sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` (" //fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey)) for _, col := range table.Columns { - if col.Name != "" { - sql += "`" + col.Name + "` " + col.SQLType.genSQL(col.Length) + " " - if col.Nullable { - sql += " NULL " - } else { - sql += " NOT NULL " - } - //fmt.Println(key) - if col.IsPrimaryKey { - sql += "PRIMARY KEY " - } - if col.AutoIncrement { - sql += e.AutoIncrement + " " - } - if col.IsUnique { - sql += "Unique " - } - sql += "," - } + sql += e.genColumnStr(&col) + sql += "," } sql = sql[:len(sql)-2] + ");" if e.ShowSQL { @@ -184,27 +173,6 @@ func (e *Engine) genDropSQL(table *Table) string { return sql } -func Type(bean interface{}) reflect.Type { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - return reflect.TypeOf(sliceValue.Interface()) -} - -func Type2SQLType(t reflect.Type) (st SQLType) { - switch k := t.Kind(); k { - case reflect.Int, reflect.Int32, reflect.Int64: - st = Int - case reflect.String: - st = Varchar - case reflect.Struct: - if t == reflect.TypeOf(time.Time{}) { - st = Date - } - default: - st = Varchar - } - return -} - /* map an object into a table object */ @@ -242,7 +210,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { case "null": col.Nullable = (tags[j-1] != "not") case "autoincr": - col.AutoIncrement = true + col.IsAutoIncrement = true case "default": col.Default = tags[j+1] case "int": @@ -255,6 +223,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { if col.SQLType.Name == "" { col.SQLType = Type2SQLType(fieldType) col.Length = col.SQLType.DefaultLength + col.Length2 = col.SQLType.DefaultLength2 } if col.Name == "" { @@ -266,7 +235,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { if col.Name == "" { sqlType := Type2SQLType(fieldType) col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, - sqlType.DefaultLength, true, "", false, false, false} + sqlType.DefaultLength, sqlType.DefaultLength2, true, "", false, false, false} } table.Columns[col.Name] = col if strings.ToLower(t.Field(i).Name) == "id" { @@ -278,7 +247,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { if pkstr != "" { col := table.Columns[pkstr] col.IsPrimaryKey = true - col.AutoIncrement = true + col.IsAutoIncrement = true col.Nullable = false col.Length = Int.DefaultLength table.PrimaryKey = col.Name @@ -292,7 +261,6 @@ func (engine *Engine) MapType(t reflect.Type) Table { func (engine *Engine) Map(beans ...interface{}) (e error) { for _, bean := range beans { - //t := getBeanType(bean) tableName := engine.Mapper.Obj2Table(StructName(bean)) if _, ok := engine.Tables[tableName]; !ok { table := engine.MapOne(bean) @@ -304,7 +272,6 @@ func (engine *Engine) Map(beans ...interface{}) (e error) { func (engine *Engine) UnMap(beans ...interface{}) (e error) { for _, bean := range beans { - //t := getBeanType(bean) tableName := engine.Mapper.Obj2Table(StructName(bean)) if _, ok := engine.Tables[tableName]; ok { delete(engine.Tables, tableName) @@ -380,7 +347,9 @@ func (engine *Engine) Insert(beans ...interface{}) (int64, error) { if err != nil { return -1, err } - + defer engine.Statement.Init() + engine.Statement.Session = &session + session.SetStatement(&engine.Statement) return session.Insert(beans...) } @@ -390,7 +359,9 @@ func (engine *Engine) Update(bean interface{}) (int64, error) { if err != nil { return -1, err } - + defer engine.Statement.Init() + engine.Statement.Session = &session + session.SetStatement(&engine.Statement) return session.Update(bean) } @@ -400,7 +371,9 @@ func (engine *Engine) Delete(bean interface{}) (int64, error) { if err != nil { return -1, err } - + defer engine.Statement.Init() + engine.Statement.Session = &session + session.SetStatement(&engine.Statement) return session.Delete(bean) } @@ -410,7 +383,9 @@ func (engine *Engine) Get(bean interface{}) error { if err != nil { return err } - + defer engine.Statement.Init() + engine.Statement.Session = &session + session.SetStatement(&engine.Statement) return session.Get(bean) } @@ -420,7 +395,9 @@ func (engine *Engine) Find(beans interface{}) error { if err != nil { return err } - + defer engine.Statement.Init() + engine.Statement.Session = &session + session.SetStatement(&engine.Statement) return session.Find(beans) } @@ -430,6 +407,8 @@ func (engine *Engine) Count(bean interface{}) (int64, error) { if err != nil { return 0, err } - + defer engine.Statement.Init() + engine.Statement.Session = &session + session.SetStatement(&engine.Statement) return session.Count(bean) } diff --git a/session.go b/session.go index 784f2d86..6aaa7f89 100644 --- a/session.go +++ b/session.go @@ -37,30 +37,28 @@ func Type2StructName(v reflect.Type) string { type Session struct { Db *sql.DB Engine *Engine + Tx *sql.Tx Statements []Statement Mapper IMapper - AutoCommit bool - ParamIteration int + IsAutoCommit bool + IsAutoRollback bool CurStatementIdx int } func (session *Session) Init() { session.Statements = make([]Statement, 0) - /*session.Statement.TableName = "" - session.Statement.LimitStr = 0 - session.Statement.OffsetStr = 0 - session.Statement.WhereStr = "" - session.Statement.ParamStr = make([]interface{}, 0) - session.Statement.OrderStr = "" - session.Statement.JoinStr = "" - session.Statement.GroupByStr = "" - session.Statement.HavingStr = ""*/ session.CurStatementIdx = -1 - - session.ParamIteration = 1 + session.IsAutoCommit = true + session.IsAutoRollback = false } func (session *Session) Close() { + rollbackfunc := func() { + if session.IsAutoRollback { + session.Rollback() + } + } + defer rollbackfunc() defer session.Db.Close() } @@ -78,102 +76,84 @@ func (session *Session) AutoStatement() *Statement { return session.CurrentStatement() } -//Execute sql -func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) { - rs, err := session.Db.Prepare(finalQueryString) - if err != nil { - return nil, err - } - defer rs.Close() - - res, err := rs.Exec(args...) - if err != nil { - return nil, err - } - return res, nil -} - -func (session *Session) Where(querystring interface{}, args ...interface{}) *Session { +func (session *Session) Where(querystring string, args ...interface{}) *Session { statement := session.AutoStatement() - switch querystring := querystring.(type) { - case string: - statement.WhereStr = querystring - case int: - if session.Engine.Protocol == "pgsql" { - statement.WhereStr = fmt.Sprintf("%v%v%v = $%v", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier, session.ParamIteration) - } else { - statement.WhereStr = fmt.Sprintf("%v%v%v = ?", session.Engine.QuoteIdentifier, statement.Table.PKColumn().Name, session.Engine.QuoteIdentifier) - session.ParamIteration++ - } - args = append(args, querystring) - } - statement.ParamStr = args + statement.Where(querystring, args...) return session } -func (session *Session) Limit(start int, size ...int) *Session { - session.AutoStatement().LimitStr = start - if len(size) > 0 { - session.CurrentStatement().OffsetStr = size[0] - } - return session -} - -func (session *Session) Offset(offset int) *Session { - session.AutoStatement().OffsetStr = offset +func (session *Session) Limit(limit int, start ...int) *Session { + statement := session.AutoStatement() + statement.Limit(limit, start...) return session } func (session *Session) OrderBy(order string) *Session { - session.AutoStatement().OrderStr = order + statement := session.AutoStatement() + statement.OrderBy(order) return session } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (session *Session) Join(join_operator, tablename, condition string) *Session { - if session.AutoStatement().JoinStr != "" { - session.CurrentStatement().JoinStr = session.CurrentStatement().JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) - } else { - session.CurrentStatement().JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) - } - + statement := session.AutoStatement() + statement.Join(join_operator, tablename, condition) return session } func (session *Session) GroupBy(keys string) *Session { - session.AutoStatement().GroupByStr = fmt.Sprintf("GROUP BY %v", keys) + statement := session.AutoStatement() + statement.GroupBy(keys) return session } func (session *Session) Having(conditions string) *Session { - session.AutoStatement().HavingStr = fmt.Sprintf("HAVING %v", conditions) + statement := session.AutoStatement() + statement.Having(conditions) return session } -func (session *Session) Begin() { - +func (session *Session) Begin() error { + session.IsAutoCommit = false + session.IsAutoRollback = true + tx, err := session.Db.Begin() + session.Tx = tx + return err } -func (session *Session) Rollback() { - +func (session *Session) Rollback() error { + return session.Tx.Rollback() } -func (session *Session) Commit() { - for _, statement := range session.Statements { - sql := statement.generateSql() - session.Exec(sql) - } +func (session *Session) Commit() error { + return session.Tx.Commit() } func (session *Session) TableName(bean interface{}) string { return session.Mapper.Obj2Table(StructName(bean)) } +func (session *Session) SetStatement(statement *Statement) { + if session.CurStatementIdx == len(session.Statements)-1 { + session.Statements = append(session.Statements, *statement) + } else { + session.Statements[session.CurStatementIdx+1] = *statement + } + session.CurStatementIdx = session.CurStatementIdx + 1 +} + func (session *Session) newStatement() { - state := Statement{} - state.Session = session - session.Statements = append(session.Statements, state) - session.CurStatementIdx = len(session.Statements) - 1 + if session.CurStatementIdx == len(session.Statements)-1 { + state := Statement{Session: session} + state.Init() + session.Statements = append(session.Statements, state) + } + session.CurStatementIdx = session.CurStatementIdx + 1 +} + +func (session *Session) clearStatment() { + session.Statements[session.CurStatementIdx].Init() + session.CurStatementIdx = session.CurStatementIdx - 1 } func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]byte) error { @@ -250,13 +230,32 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b return nil } -func (session *Session) Get(output interface{}) error { +//Execute sql +func (session *Session) Exec(finalQueryString string, args ...interface{}) (sql.Result, error) { + rs, err := session.Db.Prepare(finalQueryString) + if err != nil { + return nil, err + } + defer rs.Close() + + res, err := rs.Exec(args...) + if err != nil { + return nil, err + } + return res, nil +} + +func (session *Session) Get(bean interface{}) error { statement := session.AutoStatement() - session.Limit(1) - tableName := session.TableName(output) + statement.Limit(1) + tableName := session.TableName(bean) table := session.Engine.Tables[tableName] statement.Table = &table + colNames, args := session.BuildConditions(&table, bean) + statement.ColumnStr = strings.Join(colNames, " and ") + statement.BeanArgs = args + resultsSlice, err := session.FindMap(statement) if err != nil { return err @@ -265,7 +264,7 @@ func (session *Session) Get(output interface{}) error { return nil } else if len(resultsSlice) == 1 { results := resultsSlice[0] - err := session.scanMapIntoStruct(output, results) + err := session.scanMapIntoStruct(bean, results) if err != nil { return err } @@ -277,12 +276,15 @@ func (session *Session) Get(output interface{}) error { func (session *Session) Count(bean interface{}) (int64, error) { statement := session.AutoStatement() - session.Limit(1) tableName := session.TableName(bean) table := session.Engine.Tables[tableName] statement.Table = &table - resultsSlice, err := session.SQL2Map(statement.genCountSql(), statement.ParamStr) + colNames, args := session.BuildConditions(&table, bean) + statement.ColumnStr = strings.Join(colNames, " and ") + statement.BeanArgs = args + + resultsSlice, err := session.SQL2Map(statement.genCountSql(), append(statement.Params, statement.BeanArgs...)) if err != nil { return 0, err } @@ -346,6 +348,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli } for res.Next() { result := make(map[string][]byte) + //scanResultContainers := make([]interface{}, len(fields)) var scanResultContainers []interface{} for i := 0; i < len(fields); i++ { var scanResultContainer interface{} @@ -367,7 +370,6 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli switch aa.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: str = strconv.FormatInt(vv.Int(), 10) - result[key] = []byte(str) case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: str = strconv.FormatUint(vv.Uint(), 10) @@ -397,7 +399,7 @@ func (session *Session) SQL2Map(sqls string, paramStr []interface{}) (resultsSli func (session *Session) FindMap(statement *Statement) (resultsSlice []map[string][]byte, err error) { sqls := statement.generateSql() - return session.SQL2Map(sqls, statement.ParamStr) + return session.SQL2Map(sqls, append(statement.Params, statement.BeanArgs...)) } func (session *Session) Insert(beans ...interface{}) (int64, error) { @@ -421,7 +423,7 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { for _, col := range table.Columns { fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) val := fieldValue.Interface() - if col.AutoIncrement { + if col.IsAutoIncrement { if fieldValue.Int() == 0 { continue } @@ -442,7 +444,14 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { fmt.Println(statement) } - res, err := session.Exec(statement, args...) + var res sql.Result + var err error + + if session.IsAutoCommit { + res, err = session.Exec(statement, args...) + } else { + res, err = session.Tx.Exec(statement, args...) + } if err != nil { return -1, err } @@ -455,16 +464,10 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return id, nil } -func (session *Session) Update(bean interface{}) (int64, error) { - tableName := session.TableName(bean) - table := session.Engine.Tables[tableName] - +func (session *Session) BuildConditions(table *Table, bean interface{}) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { - if col.Name == "" { - continue - } fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) fieldType := reflect.TypeOf(fieldValue.Interface()) val := fieldValue.Interface() @@ -491,9 +494,18 @@ func (session *Session) Update(bean interface{}) (int64, error) { colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") } + return colNames, args +} + +func (session *Session) Update(bean interface{}) (int64, error) { + tableName := session.TableName(bean) + table := session.Engine.Tables[tableName] + colNames, args := session.BuildConditions(&table, bean) + var condition = "" - st := session.CurrentStatement() - if st != nil && st.WhereStr != "" { + st := session.AutoStatement() + defer session.clearStatment() + if st.WhereStr != "" { condition = fmt.Sprintf("WHERE %v", st.WhereStr) } @@ -516,7 +528,15 @@ func (session *Session) Update(bean interface{}) (int64, error) { fmt.Println(statement) } - res, err := session.Exec(statement, args...) + var res sql.Result + var err error + if session.IsAutoCommit { + fmt.Println("session.Exec") + res, err = session.Exec(statement, append(args, st.Params...)...) + } else { + fmt.Println("tx.Exec") + res, err = session.Tx.Exec(statement, append(args, st.Params...)...) + } if err != nil { return -1, err } @@ -532,42 +552,12 @@ func (session *Session) Update(bean interface{}) (int64, error) { func (session *Session) Delete(bean interface{}) (int64, error) { tableName := session.TableName(bean) table := session.Engine.Tables[tableName] - - colNames := make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns { - if col.Name == "" { - continue - } - fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) - fieldType := reflect.TypeOf(fieldValue.Interface()) - val := fieldValue.Interface() - switch fieldType.Kind() { - case reflect.String: - if fieldValue.String() == "" { - continue - } - case reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) - if t.IsZero() { - continue - } - } - default: - continue - } - args = append(args, val) - colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") - } + colNames, args := session.BuildConditions(&table, bean) var condition = "" - st := session.CurrentStatement() - if st != nil && st.WhereStr != "" { + st := session.AutoStatement() + defer session.clearStatment() + if st.WhereStr != "" { condition = fmt.Sprintf("WHERE %v", st.WhereStr) if len(colNames) > 0 { condition += " and " @@ -587,7 +577,13 @@ func (session *Session) Delete(bean interface{}) (int64, error) { fmt.Println(statement) } - res, err := session.Exec(statement, args...) + var res sql.Result + var err error + if session.IsAutoCommit { + res, err = session.Exec(statement, append(st.Params, args...)...) + } else { + res, err = session.Tx.Exec(statement, append(st.Params, args...)...) + } if err != nil { return -1, err } diff --git a/statement.go b/statement.go index 868ad018..f683eba1 100644 --- a/statement.go +++ b/statement.go @@ -5,40 +5,66 @@ import ( ) type Statement struct { - TableName string Table *Table Session *Session - LimitStr int - OffsetStr int + Start int + LimitN int WhereStr string - ParamStr []interface{} + Params []interface{} OrderStr string JoinStr string GroupByStr string HavingStr string + ColumnStr string + BeanArgs []interface{} } -func (statement *Statement) Limit(start int, size ...int) *Statement { - statement.LimitStr = start - if len(size) > 0 { - statement.OffsetStr = size[0] +func (statement *Statement) Init() { + statement.Table = nil + statement.Session = nil + statement.Start = 0 + statement.LimitN = 0 + statement.WhereStr = "" + statement.Params = make([]interface{}, 0) + statement.OrderStr = "" + statement.JoinStr = "" + statement.GroupByStr = "" + statement.HavingStr = "" + statement.ColumnStr = "" + statement.BeanArgs = make([]interface{}, 0) +} + +func (statement *Statement) Where(querystring string, args ...interface{}) { + statement.WhereStr = querystring + statement.Params = args +} + +func (statement *Statement) Limit(limit int, start ...int) { + statement.LimitN = limit + if len(start) > 0 { + statement.Start = start[0] } - return statement } -func (statement *Statement) Offset(offset int) *Statement { - statement.OffsetStr = offset - return statement -} - -func (statement *Statement) OrderBy(order string) *Statement { +func (statement *Statement) OrderBy(order string) { statement.OrderStr = order - return statement } -func (statement *Statement) Select(colums string) *Statement { - //statement.ColumnStr = colums - return statement +//The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN +func (statement *Statement) Join(join_operator, tablename, condition string) { + if statement.JoinStr != "" { + statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } else { + statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + } +} + +func (statement *Statement) GroupBy(keys string) { + statement.GroupByStr = fmt.Sprintf("GROUP BY %v", keys) +} + +func (statement *Statement) Having(conditions string) { + statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) } func (statement Statement) generateSql() string { @@ -50,27 +76,41 @@ func (statement Statement) genCountSql() string { return statement.genSelectSql("count(*) as total") } +func (statement Statement) genExecSql() string { + return "" +} + func (statement Statement) genSelectSql(columnStr string) (a string) { session := statement.Session if session.Engine.Protocol == "mssql" { - if statement.OffsetStr > 0 { + if statement.Start > 0 { a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", statement.Table.PKColumn().Name, columnStr, statement.Table.Name) if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + if statement.ColumnStr != "" { + a = fmt.Sprintf("%v and %v", a, statement.ColumnStr) + } + } else if statement.ColumnStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr) } a = fmt.Sprintf("select %v from (%v) "+ "as a where rownum between %v and %v", columnStr, a, - statement.OffsetStr, - statement.LimitStr) - } else if statement.LimitStr > 0 { - a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitStr, columnStr, statement.Table.Name) + statement.Start, + statement.LimitN) + } else if statement.LimitN > 0 { + a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.Table.Name) if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + if statement.ColumnStr != "" { + a = fmt.Sprintf("%v and %v", a, statement.ColumnStr) + } + } else if statement.ColumnStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr) } if statement.GroupByStr != "" { a = fmt.Sprintf("%v %v", a, statement.GroupByStr) @@ -85,6 +125,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + if statement.ColumnStr != "" { + a = fmt.Sprintf("%v and %v", a, statement.ColumnStr) + } + } else if statement.ColumnStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr) } if statement.GroupByStr != "" { a = fmt.Sprintf("%v %v", a, statement.GroupByStr) @@ -103,6 +148,11 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { } if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + if statement.ColumnStr != "" { + a = fmt.Sprintf("%v and %v", a, statement.ColumnStr) + } + } else if statement.ColumnStr != "" { + a = fmt.Sprintf("%v WHERE %v", a, statement.ColumnStr) } if statement.GroupByStr != "" { a = fmt.Sprintf("%v %v", a, statement.GroupByStr) @@ -113,10 +163,10 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if statement.OffsetStr > 0 { - a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.OffsetStr, statement.LimitStr) - } else if statement.LimitStr > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitStr) + if statement.Start > 0 { + a = fmt.Sprintf("%v LIMIT %v, %v", a, statement.Start, statement.LimitN) + } else if statement.LimitN > 0 { + a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } } return diff --git a/table.go b/table.go new file mode 100644 index 00000000..5503ed99 --- /dev/null +++ b/table.go @@ -0,0 +1,90 @@ +package xorm + +import ( + "reflect" + "strconv" + "strings" + "time" +) + +type SQLType struct { + Name string + DefaultLength int + DefaultLength2 int +} + +var ( + Int = SQLType{"int", 11, 0} + Char = SQLType{"char", 1, 0} + Bool = SQLType{"int", 1, 0} + Varchar = SQLType{"varchar", 50, 0} + Date = SQLType{"date", 24, 0} + Decimal = SQLType{"decimal", 26, 2} + Float = SQLType{"float", 31, 0} + Double = SQLType{"double", 31, 0} +) + +func (sqlType SQLType) genSQL(length int) string { + if sqlType == Date { + return " datetime " + } + return sqlType.Name + "(" + strconv.Itoa(length) + ")" +} + +func Type2SQLType(t reflect.Type) (st SQLType) { + switch k := t.Kind(); k { + case reflect.Int, reflect.Int32, reflect.Int64: + st = Int + case reflect.Bool: + st = Bool + case reflect.String: + st = Varchar + case reflect.Struct: + if t == reflect.TypeOf(time.Time{}) { + st = Date + } + default: + st = Varchar + } + return +} + +type Column struct { + Name string + FieldName string + SQLType SQLType + Length int + Length2 int + Nullable bool + Default string + IsUnique bool + IsPrimaryKey bool + IsAutoIncrement bool +} + +type Table struct { + Name string + Type reflect.Type + Columns map[string]Column + PrimaryKey string +} + +func (table *Table) ColumnStr() string { + colNames := make([]string, 0) + for _, col := range table.Columns { + colNames = append(colNames, col.Name) + } + return strings.Join(colNames, ", ") +} + +/*func (table *Table) PlaceHolders() string { + colNames := make([]string, 0) + for _, col := range table.Columns { + colNames = append(colNames, "?") + } + return strings.Join(colNames, ", ") +}*/ + +func (table *Table) PKColumn() Column { + return table.Columns[table.PrimaryKey] +} diff --git a/xorm_test.go b/xorm_test.go index f785851c..a0f181a5 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -3,8 +3,7 @@ package xorm_test import ( "fmt" _ "github.com/Go-SQL-Driver/MySQL" - //_ "github.com/ziutek/mymysql/godrv" - //_ "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" "testing" "time" "xorm" @@ -36,21 +35,14 @@ type Userinfo struct { var engine xorm.Engine -func TestCreateEngine(t *testing.T) { - engine = xorm.Create("mysql://root:123@localhost/test") - //engine = orm.Create("mymysql://root:123@localhost/test") - //engine = orm.Create("sqlite:///test.db") - engine.ShowSQL = true -} - -func TestDirectCreateTable(t *testing.T) { +func directCreateTable(t *testing.T) { err := engine.CreateTables(&Userinfo{}) if err != nil { t.Error(err) } } -func TestMapper(t *testing.T) { +func mapper(t *testing.T) { err := engine.UnMap(&Userinfo{}) if err != nil { t.Error(err) @@ -72,7 +64,7 @@ func TestMapper(t *testing.T) { } } -func TestInsert(t *testing.T) { +func insert(t *testing.T) { user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()} _, err := engine.Insert(&user) if err != nil { @@ -80,7 +72,7 @@ func TestInsert(t *testing.T) { } } -func TestInsertAutoIncr(t *testing.T) { +func insertAutoIncr(t *testing.T) { // auto increment insert user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()} _, err := engine.Insert(&user) @@ -89,7 +81,7 @@ func TestInsertAutoIncr(t *testing.T) { } } -func TestInsertMulti(t *testing.T) { +func insertMulti(t *testing.T) { user1 := Userinfo{Username: "xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()} user2 := Userinfo{Username: "xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()} _, err := engine.Insert(&user1, &user2) @@ -98,7 +90,7 @@ func TestInsertMulti(t *testing.T) { } } -func TestUpdate(t *testing.T) { +func update(t *testing.T) { // update by id user := Userinfo{Uid: 1, Username: "xxx"} _, err := engine.Update(&user) @@ -107,7 +99,7 @@ func TestUpdate(t *testing.T) { } } -func TestDelete(t *testing.T) { +func delete(t *testing.T) { user := Userinfo{Uid: 1} _, err := engine.Delete(&user) if err != nil { @@ -115,7 +107,7 @@ func TestDelete(t *testing.T) { } } -func TestGet(t *testing.T) { +func get(t *testing.T) { user := Userinfo{Uid: 2} err := engine.Get(&user) @@ -125,7 +117,7 @@ func TestGet(t *testing.T) { fmt.Println(user) } -func TestFind(t *testing.T) { +func find(t *testing.T) { users := make([]Userinfo, 0) err := engine.Find(&users) @@ -135,8 +127,8 @@ func TestFind(t *testing.T) { fmt.Println(users) } -func TestCount(t *testing.T) { - user := Userinfo{} +func count(t *testing.T) { + user := Userinfo{Departname: "dev"} total, err := engine.Count(&user) if err != nil { t.Error(err) @@ -144,47 +136,117 @@ func TestCount(t *testing.T) { fmt.Printf("Total %d records!!!", total) } -func TestWhere(t *testing.T) { +func where(t *testing.T) { users := make([]Userinfo, 0) - session, err := engine.MakeSession() - defer session.Close() - if err != nil { - t.Error(err) - } - err = session.Where("id > ?", 2).Find(&users) + err := engine.Where("id > ?", 2).Find(&users) if err != nil { t.Error(err) } fmt.Println(users) } -func TestLimit(t *testing.T) { +func limit(t *testing.T) { users := make([]Userinfo, 0) - session, err := engine.MakeSession() - defer session.Close() - if err != nil { - t.Error(err) - } - err = session.Limit(2, 1).Find(&users) + err := engine.Limit(2, 1).Find(&users) if err != nil { t.Error(err) } fmt.Println(users) } -func TestOrder(t *testing.T) { +func order(t *testing.T) { users := make([]Userinfo, 0) - session, err := engine.MakeSession() - defer session.Close() - if err != nil { - t.Error(err) - } - err = session.OrderBy("id desc").Find(&users) + err := engine.OrderBy("id desc").Find(&users) if err != nil { t.Error(err) } fmt.Println(users) } -func TestTransaction(*testing.T) { +func transaction(t *testing.T) { + counter := func() { + total, err := engine.Count(&Userinfo{}) + if err != nil { + t.Error(err) + } + fmt.Printf("----now total %v records\n", total) + } + + counter() + + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + return + } + + defer counter() + + session.Begin() + session.IsAutoRollback = true + user1 := Userinfo{Username: "xiaoxiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + _, err = session.Insert(&user1) + if err != nil { + t.Error(err) + return + } + user2 := Userinfo{Username: "yyy"} + _, err = session.Where("id = ?", 2).Update(&user2) + if err != nil { + t.Error(err) + return + } + + _, err = session.Delete(&user2) + if err != nil { + t.Error(err) + return + } + + err = session.Commit() + if err != nil { + t.Error(err) + return + } +} + +func TestMysql(t *testing.T) { + engine = xorm.Create("mysql://root:123@localhost/test") + engine.ShowSQL = true + + directCreateTable(t) + mapper(t) + insert(t) + insertAutoIncr(t) + insertMulti(t) + update(t) + delete(t) + get(t) + find(t) + count(t) + where(t) + limit(t) + order(t) + transaction(t) +} + +func TestSqlite(t *testing.T) { + engine = xorm.Create("sqlite:///test.db") + engine.ShowSQL = true + + directCreateTable(t) + mapper(t) + insert(t) + insertAutoIncr(t) + insertMulti(t) + update(t) + delete(t) + get(t) + find(t) + count(t) + where(t) + limit(t) + order(t) + transaction(t) }