diff --git a/README.md b/README.md index 0a48af25..b1b5b6ec 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [中文](https://github.com/lunny/xorm/blob/master/README_CN.md) -xorm is an ORM for Go. It makes dabatabse operating simple. +xorm is a simple and powerful ORM for Go. It makes dabatabse operating simple. It's not entirely ready for product use yet, but it's getting there. @@ -16,6 +16,7 @@ Drivers for Go's sql package which currently support database/sql includes: ## Changelog +* **v0.1.4** : Added simple cascade load support; added more data type supports. * **v0.1.3** : Find function now supports both slice and map; Add Table function for multi tables and temperory tables support * **v0.1.2** : Insert function now supports both struct and slice pointer parameters, batch inserting and auto transaction * **v0.1.1** : Add Id, In functions and improved README @@ -31,7 +32,10 @@ Drivers for Go's sql package which currently support database/sql includes: * Simply usage -* Support Id, In, Where, Limit, Join, Having functions and sturct as query conditions +* Support Id, In, Where, Limit, Join, Having functions and sturct as query conditions + +* Support simple cascade load just like Hibernate for Java + ## Installing xorm @@ -111,11 +115,14 @@ var user = User{Name:"xlw"} err := engine.Get(&user) ``` -6.Fetch multipe objects, use Find: +6.Fetch multipe objects into a slice or a map, use Find: ```Go var everyone []Userinfo err := engine.Find(&everyone) + +users := make(map[int64]Userinfo) +err := engine.Find(&users) ``` 6.1 also you can use Where, Limit @@ -293,7 +300,7 @@ Another is use field tag, field tag support the below keywords which split with pkthe field is a primary key - int(11)/varchar(50)column type + int(11)/varchar(50)/text/date/datetime/blob/decimal(26,2)column type autoincrauto incrment diff --git a/README_CN.md b/README_CN.md index 4ecbe2de..385dcfc4 100644 --- a/README_CN.md +++ b/README_CN.md @@ -16,6 +16,7 @@ xorm是一个Go语言的ORM库. 通过它可以使数据库操作非常简便。 ## 更新日志 +* **v0.1.4** : Get函数和Find函数新增简单的级联载入功能;对更多的数据库类型支持。 * **v0.1.3** : Find函数现在支持传入Slice或者Map,当传入Map时,key为id;新增Table函数以为多表和临时表进行支持。 * **v0.1.2** : Insert函数支持混合struct和slice指针传入,并根据数据库类型自动批量插入,同时自动添加事务 * **v0.1.1** : 添加 Id, In 函数,改善 README 文档 @@ -108,11 +109,14 @@ var user = User{Name:"xlw"} err := engine.Get(&user) ``` -6.获取多个对象,可以用Find方法: +6.获取多个对象到一个Slice或一个Map对象中,可以用Find方法: ```Go var everyone []Userinfo err := engine.Find(&everyone) + +users := make(map[int64]Userinfo) +err := engine.Find(&users) ``` 6.1 你也可以使用Where和Limit方法设定条件和查询数量 @@ -289,7 +293,7 @@ UserInfo中的成员UserName将会自动对应名为user_name的字段。 pk是否是Primary Key - int(11)/varchar(50)字段类型 + int(11)/varchar(50)/text/date/datetime/blob/decimal(26,2)字段类型 autoincr是否是自增 diff --git a/engine.go b/engine.go index 1d67c0fc..3c68cb88 100644 --- a/engine.go +++ b/engine.go @@ -16,11 +16,16 @@ const ( MYMYSQL = "mymysql" ) +type dialect interface { + SqlType(t *Column) string +} + type Engine struct { Mapper IMapper TagIdentifier string DriverName string DataSourceName string + Dialect dialect Tables map[reflect.Type]Table AutoIncrement string ShowSQL bool @@ -119,8 +124,6 @@ func (engine *Engine) AutoMap(bean interface{}) *Table { func (engine *Engine) MapType(t reflect.Type) Table { table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} table.Columns = make(map[string]Column) - var pkCol *Column = nil - var pkstr = "" for i := 0; i < t.NumField(); i++ { tag := t.Field(i).Tag @@ -129,10 +132,11 @@ func (engine *Engine) MapType(t reflect.Type) Table { fieldType := t.Field(i).Type if ormTagStr != "" { - col = Column{FieldName: t.Field(i).Name, Nullable: true} + col = Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, + IsAutoIncrement: false} ormTagStr = strings.ToLower(ormTagStr) tags := strings.Split(ormTagStr, " ") - // TODO: + // TODO: if len(tags) > 0 { if tags[0] == "-" { continue @@ -143,7 +147,6 @@ func (engine *Engine) MapType(t reflect.Type) Table { case k == "pk": col.IsPrimaryKey = true col.Nullable = false - pkCol = &col case k == "null": col.Nullable = (tags[j-1] != "not") case k == "autoincr": @@ -158,7 +161,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { col.Length, _ = strconv.Atoi(lens) case strings.HasPrefix(k, "varchar"): col.SQLType = Varchar - lens := k[len("decimal")+1 : len(k)-1] + lens := k[len("varchar")+1 : len(k)-1] col.Length, _ = strconv.Atoi(lens) case strings.HasPrefix(k, "decimal"): col.SQLType = Decimal @@ -168,6 +171,10 @@ func (engine *Engine) MapType(t reflect.Type) Table { col.Length2, _ = strconv.Atoi(twolen[1]) case k == "date": col.SQLType = Date + case k == "datetime": + col.SQLType = DateTime + case k == "timestamp": + col.SQLType = TimeStamp case k == "not": default: if k != col.Default { @@ -189,31 +196,23 @@ func (engine *Engine) MapType(t reflect.Type) Table { if col.Name == "" { col.Name = engine.Mapper.Obj2Table(t.Field(i).Name) } + if col.IsPrimaryKey { + table.PrimaryKey = col.Name + } } - } - - if col.Name == "" { + } else { sqlType := Type2SQLType(fieldType) col = Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, sqlType.DefaultLength2, true, "", false, false, false} + + if col.Name == "id" { + col.IsPrimaryKey = true + col.IsAutoIncrement = true + col.Nullable = false + table.PrimaryKey = col.Name + } } table.Columns[col.Name] = col - if strings.ToLower(t.Field(i).Name) == "id" { - pkstr = col.Name - } - } - - if pkCol == nil { - if pkstr != "" { - col := table.Columns[pkstr] - col.IsPrimaryKey = true - col.IsAutoIncrement = true - col.Nullable = false - col.Length = Int.DefaultLength - table.PrimaryKey = col.Name - } - } else { - table.PrimaryKey = pkCol.Name } return table diff --git a/mysql.go b/mysql.go new file mode 100644 index 00000000..7df2f680 --- /dev/null +++ b/mysql.go @@ -0,0 +1,19 @@ +package xorm + +import "strconv" + +type mysql struct { +} + +func (db mysql) SqlType(c *Column) string { + switch t := c.SQLType; t { + case Date, DateTime, TimeStamp: + return "DATETIME" + case Varchar: + return t.Name + "(" + strconv.Itoa(c.Length) + ")" + case Decimal: + return t.Name + "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + default: + return t.Name + } +} diff --git a/session.go b/session.go index 90cce292..ef5c3881 100644 --- a/session.go +++ b/session.go @@ -59,6 +59,13 @@ func (session *Session) OrderBy(order string) *Session { return session } +func (session *Session) Cascade(trueOrFalse ...bool) *Session { + if len(trueOrFalse) >= 1 { + session.Statement.UseCascade = trueOrFalse[0] + } + return session +} + //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (session *Session) Join(join_operator, tablename, condition string) *Session { session.Statement.Join(join_operator, tablename, condition) @@ -130,6 +137,10 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b switch structField.Type().Kind() { case reflect.Slice: v = data + case reflect.Array: + if structField.Type().Elem() == reflect.TypeOf(b) { + v = data + } case reflect.String: v = string(data) case reflect.Bool: @@ -160,20 +171,44 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b v = x //Now only support Time type case reflect.Struct: - if structField.Type().String() != "time.Time" { - return errors.New("unsupported struct type in Scan: " + structField.Type().String()) - } - - x, err := time.Parse("2006-01-02 15:04:05", string(data)) - if err != nil { - x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data)) - + if structField.Type().String() == "time.Time" { + x, err := time.Parse("2006-01-02 15:04:05", string(data)) if err != nil { - return errors.New("unsupported time format: " + string(data)) + x, err = time.Parse("2006-01-02 15:04:05.000 -0700", string(data)) + + if err != nil { + return errors.New("unsupported time format: " + string(data)) + } + } + + v = x + } else if session.Statement.UseCascade { + session.Engine.AutoMapType(structField.Type()) + if _, ok := session.Engine.Tables[structField.Type()]; ok { + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return errors.New("arg " + key + " as int: " + err.Error()) + } + + if x != 0 { + structInter := reflect.New(structField.Type()) + session.Statement.Init() + err = session.Id(x).Get(structInter.Interface()) + if err != nil { + return err + } + + v = structInter.Elem().Interface() + } else { + //fmt.Println("zero value of struct type " + structField.Type().String()) + continue + } + + } else { + fmt.Println("unsupported struct type in Scan: " + structField.Type().String()) + continue } } - - v = x default: return errors.New("unsupported type in Scan: " + reflect.TypeOf(v).String()) } @@ -205,6 +240,7 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error } if session.Engine.ShowSQL { fmt.Println(sql) + fmt.Println(args) } if session.IsAutoCommit { return session.innerExec(sql, args...) @@ -226,6 +262,8 @@ func (session *Session) Get(bean interface{}) error { defer statement.Init() statement.Limit(1) + fmt.Println(bean) + sql, args := statement.genGetSql(bean) resultsSlice, err := session.Query(sql, args...) @@ -321,6 +359,7 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice } if session.Engine.ShowSQL { fmt.Println(sql) + fmt.Println(paramStr) } s, err := session.Db.Prepare(sql) if err != nil { @@ -375,11 +414,14 @@ func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice case reflect.String: str = vv.String() result[key] = []byte(str) - //时间类型 + //时间类型 case reflect.Struct: - str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") - result[key] = []byte(str) + if aa.String() == "time.Time" { + str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700") + result[key] = []byte(str) + } } + //default: } resultsSlice = append(resultsSlice, result) @@ -465,7 +507,13 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { if col.IsAutoIncrement && fieldValue.Int() == 0 { continue } - args = append(args, val) + if table, ok := session.Engine.Tables[fieldValue.Type()]; ok { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) + fmt.Println(pkField.Interface()) + args = append(args, pkField.Interface()) + } else { + args = append(args, val) + } colNames = append(colNames, col.Name) cols = append(cols, col) colPlaces = append(colPlaces, "?") @@ -477,7 +525,12 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { if col.IsAutoIncrement && fieldValue.Int() == 0 { continue } - args = append(args, val) + if table, ok := session.Engine.Tables[fieldValue.Type()]; ok { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) + args = append(args, pkField.Interface()) + } else { + args = append(args, val) + } colPlaces = append(colPlaces, "?") } } @@ -517,7 +570,12 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { if col.IsAutoIncrement && fieldValue.Int() == 0 { continue } - args = append(args, val) + if table, ok := session.Engine.Tables[fieldValue.Type()]; ok { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) + args = append(args, pkField.Interface()) + } else { + args = append(args, val) + } colNames = append(colNames, col.Name) colPlaces = append(colPlaces, "?") } diff --git a/sqlite3.go b/sqlite3.go new file mode 100644 index 00000000..9da9811c --- /dev/null +++ b/sqlite3.go @@ -0,0 +1,23 @@ +package xorm + +type sqlite3 struct { +} + +func (db sqlite3) SqlType(c *Column) string { + switch t := c.SQLType; t { + case Date, DateTime, TimeStamp: + return "NUMERIC" + case Char, Varchar, Text: + return "TEXT" + case TinyInt, SmallInt, MediumInt, Int, BigInt: + return "INTEGER" + case Float, Double: + return "REAL" + case Decimal: + return "NUMERIC" + case Blob: + return "BLOB" + default: + return t.Name + } +} diff --git a/statement.go b/statement.go index c80c0cf6..dc18aa90 100644 --- a/statement.go +++ b/statement.go @@ -3,7 +3,7 @@ package xorm import ( "fmt" "reflect" - "strconv" + //"strconv" "strings" "time" ) @@ -21,6 +21,7 @@ type Statement struct { HavingStr string ColumnStr string AltTableName string + UseCascade bool BeanArgs []interface{} } @@ -39,6 +40,7 @@ func (statement *Statement) Init() { statement.WhereStr = "" statement.Params = make([]interface{}, 0) statement.OrderStr = "" + statement.UseCascade = true statement.JoinStr = "" statement.GroupByStr = "" statement.HavingStr = "" @@ -82,7 +84,17 @@ func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, default: continue } - args = append(args, val) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) + fmt.Println(pkField.Interface()) + if pkField.Int() != 0 { + args = append(args, pkField.Interface()) + } else { + continue + } + } else { + args = append(args, val) + } colNames = append(colNames, engine.QuoteIdentifier+col.Name+engine.QuoteIdentifier+"=?") } @@ -150,38 +162,30 @@ func (statement *Statement) Having(conditions string) { func (statement *Statement) genColumnStr(col *Column) string { sql := "`" + col.Name + "` " - if col.SQLType == Date { - sql += " datetime " - } else { - if statement.Engine.DriverName == SQLITE && col.IsPrimaryKey { - sql += "integer" - } else { - sql += col.SQLType.Name - } - if statement.Engine.DriverName != SQLITE && col.SQLType != Text { - if col.SQLType != Decimal { - sql += "(" + strconv.Itoa(col.Length) + ")" - } else { - sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")" - } - } - } - if col.Nullable { - sql += " NULL " - } else { - sql += " NOT NULL " - } - //fmt.Println(key) + sql += statement.Engine.Dialect.SqlType(col) + " " + if col.IsPrimaryKey { sql += "PRIMARY KEY " } + if col.IsAutoIncrement { sql += statement.Engine.AutoIncrement + " " } + + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + if col.IsUnique { sql += "Unique " } + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } return sql } @@ -198,7 +202,8 @@ func (statement *Statement) genCreateSQL() string { sql := "CREATE TABLE IF NOT EXISTS `" + statement.TableName() + "` (" for _, col := range statement.RefTable.Columns { sql += statement.genColumnStr(&col) - sql += "," + sql = strings.TrimSpace(sql) + sql += ", " } sql = sql[:len(sql)-2] + ");" return sql diff --git a/table.go b/table.go index d79e6f86..81667b38 100644 --- a/table.go +++ b/table.go @@ -2,7 +2,7 @@ package xorm import ( "reflect" - "strconv" + //"strconv" //"strings" "time" ) @@ -14,35 +14,49 @@ type SQLType struct { } var ( - Int = SQLType{"int", 11, 0} - Char = SQLType{"char", 1, 0} - Bool = SQLType{"int", 1, 0} - Varchar = SQLType{"varchar", 50, 0} - Text = SQLType{"text", 16, 0} - Date = SQLType{"date", 24, 0} - Decimal = SQLType{"decimal", 26, 2} - Float = SQLType{"float", 31, 0} - Double = SQLType{"double", 31, 0} + TinyInt = SQLType{"TINYINT", 0, 0} + SmallInt = SQLType{"SMALLINT", 0, 0} + MediumInt = SQLType{"MEDIUMINT", 0, 0} + Int = SQLType{"INT", 11, 0} + BigInt = SQLType{"BIGINT", 0, 0} + Char = SQLType{"CHAR", 1, 0} + Varchar = SQLType{"VARCHAR", 64, 0} + Text = SQLType{"TEXT", 16, 0} + Date = SQLType{"DATE", 24, 0} + DateTime = SQLType{"DATETIME", 0, 0} + Decimal = SQLType{"DECIMAL", 26, 2} + Float = SQLType{"FLOAT", 31, 0} + Double = SQLType{"DOUBLE", 31, 0} + Blob = SQLType{"BLOB", 0, 0} + TimeStamp = SQLType{"TIMESTAMP", 0, 0} ) -func (sqlType SQLType) genSQL(length int) string { - if sqlType == Date { - return " datetime " - } - return sqlType.Name + "(" + strconv.Itoa(length) + ")" -} +var b byte +var tm time.Time func Type2SQLType(t reflect.Type) (st SQLType) { switch k := t.Kind(); k { - case reflect.Int, reflect.Int32, reflect.Int64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: st = Int + case reflect.Int64, reflect.Uint64: + st = BigInt + case reflect.Float32: + st = Float + case reflect.Float64: + st = Double + case reflect.Complex64, reflect.Complex128: + st = Varchar + case reflect.Array, reflect.Slice: + if t.Elem() == reflect.TypeOf(b) { + st = Blob + } case reflect.Bool: - st = Bool + st = TinyInt case reflect.String: st = Varchar case reflect.Struct: - if t == reflect.TypeOf(time.Time{}) { - st = Date + if t == reflect.TypeOf(tm) { + st = DateTime } default: st = Varchar diff --git a/xorm.go b/xorm.go index 6c2d7911..53618481 100644 --- a/xorm.go +++ b/xorm.go @@ -13,17 +13,14 @@ func Create(driverName string, dataSourceName string) Engine { engine.InsertMany = true engine.TagIdentifier = "xorm" if driverName == SQLITE { + engine.Dialect = sqlite3{} engine.AutoIncrement = "AUTOINCREMENT" } else { + engine.Dialect = mysql{} engine.AutoIncrement = "AUTO_INCREMENT" } - if engine.DriverName == PQSQL { - engine.QuoteIdentifier = "\"" - } else if engine.DriverName == MSSQL { - engine.QuoteIdentifier = "" - } else { - engine.QuoteIdentifier = "`" - } + engine.QuoteIdentifier = "`" + return engine } diff --git a/xorm_test.go b/xorm_test.go index 81af8189..3a425374 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -4,6 +4,7 @@ import ( "fmt" _ "github.com/Go-SQL-Driver/MySQL" _ "github.com/mattn/go-sqlite3" + "os" "testing" "time" "xorm" @@ -26,15 +27,19 @@ CREATE TABLE `userdeatail` ( */ type Userinfo struct { - Uid int `xorm:"id pk not null autoincr"` + Uid int64 `xorm:"id pk not null autoincr"` Username string Departname string Alias string `xorm:"-"` Created time.Time + Detail Userdetail `xorm:"detail_id int(11)"` + Height float64 + Avatar []byte + IsMan bool } type Userdetail struct { - Uid int `xorm:"id pk not null"` + Id int64 Intro string `xorm:"text"` Profile string `xorm:"varchar(2000)"` } @@ -71,7 +76,8 @@ func mapper(t *testing.T) { } func insert(t *testing.T) { - user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now()} + user := Userinfo{1, "xiaolunwen", "dev", "lunny", time.Now(), + Userdetail{Id: 1}, 1.78, []byte{1, 2, 3}, true} _, err := engine.Insert(&user) if err != nil { t.Error(err) @@ -98,7 +104,8 @@ func exec(t *testing.T) { func insertAutoIncr(t *testing.T) { // auto increment insert - user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now()} + user := Userinfo{Username: "xiaolunwen", Departname: "dev", Alias: "lunny", Created: time.Now(), + Detail: Userdetail{Id: 1}, Height: 1.78, Avatar: []byte{1, 2, 3}, IsMan: true} _, err := engine.Insert(&user) if err != nil { t.Error(err) @@ -135,8 +142,9 @@ func insertMulti(t *testing.T) { } func insertTwoTable(t *testing.T) { - userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now()} - userdetail := Userdetail{Uid: 1, Intro: "I'm a very beautiful women.", Profile: "sfsaf"} + userdetail := Userdetail{Id: 1, Intro: "I'm a very beautiful women.", Profile: "sfsaf"} + userinfo := Userinfo{Username: "xlw3", Departname: "dev", Alias: "lunny4", Created: time.Now(), Detail: userdetail} + _, err := engine.Insert(&userinfo, &userdetail) if err != nil { t.Error(err) @@ -176,6 +184,16 @@ func get(t *testing.T) { fmt.Println(user) } +func cascadeGet(t *testing.T) { + user := Userinfo{Uid: 11} + + err := engine.Get(&user) + if err != nil { + t.Error(err) + } + fmt.Println(user) +} + func find(t *testing.T) { users := make([]Userinfo, 0) @@ -380,7 +398,7 @@ func createMultiTables(t *testing.T) { user := &Userinfo{} session.Begin() - for i := 0; i < 100; i++ { + for i := 0; i < 10; i++ { err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user) if err != nil { session.Rollback() @@ -425,6 +443,7 @@ func tableOp(t *testing.T) { } func TestMysql(t *testing.T) { + // You should drop all tables before executing this testing engine = xorm.Create("mysql", "root:123@/test?charset=utf8") engine.ShowSQL = true @@ -439,6 +458,7 @@ func TestMysql(t *testing.T) { update(t) delete(t) get(t) + cascadeGet(t) find(t) findMap(t) count(t) @@ -456,6 +476,7 @@ func TestMysql(t *testing.T) { } func TestSqlite(t *testing.T) { + os.Remove("./test.db") engine = xorm.Create("sqlite3", "./test.db") engine.ShowSQL = true @@ -470,6 +491,7 @@ func TestSqlite(t *testing.T) { update(t) delete(t) get(t) + cascadeGet(t) find(t) findMap(t) count(t)