diff --git a/README.md b/README.md index 730f890a..5cb7eab6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Xorm is a simple and powerful ORM for Go. -[![Build Status](https://drone.io/github.com/go-xorm/xorm/status.png)](https://drone.io/github.com/go-xorm/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") +[![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") # Features @@ -37,8 +37,12 @@ Drivers for Go's sql package which currently support database/sql includes: * Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) + * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) + + # Changelog * **v0.4.0 RC1** diff --git a/README_CN.md b/README_CN.md index f3974aa3..8aee186a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -4,7 +4,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 -[![Build Status](https://drone.io/github.com/go-xorm/xorm/status.png)](https://drone.io/github.com/go-xorm/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) +[![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/xorm/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) ## 特性 @@ -38,6 +38,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) + * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) ## 更新日志 diff --git a/docs/QuickStart.md b/docs/QuickStart.md index 47aa98dc..6f568071 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -91,7 +91,7 @@ f, err := os.Create("sql.log") println(err.Error()) return } -engine.Logger = f +engine.Logger = xorm.NewSimpleLogger(f) ``` 3.Engine provide DB connection pool settings. diff --git a/docs/QuickStartCN.md b/docs/QuickStartCN.md index 366a6cad..d0fd45b1 100644 --- a/docs/QuickStartCN.md +++ b/docs/QuickStartCN.md @@ -95,7 +95,7 @@ f, err := os.Create("sql.log") println(err.Error()) return } -engine.Logger = f +engine.Logger = xorm.NewSimpleLogger(f) ``` 3.engine内部支持连接池接口。 diff --git a/engine.go b/engine.go index fe09d82b..74308f7f 100644 --- a/engine.go +++ b/engine.go @@ -469,6 +469,13 @@ func (engine *Engine) Incr(column string, arg ...interface{}) *Session { return session.Incr(column, arg...) } +// Method Decr provides a update string like "column = column - ?" +func (engine *Engine) Decr(column string, arg ...interface{}) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Decr(column, arg...) +} + // Temporarily change the Get, Find, Update's table func (engine *Engine) Table(tableNameOrBean interface{}) *Session { session := engine.NewSession() @@ -1110,21 +1117,21 @@ func (engine *Engine) Sync2(beans ...interface{}) error { if engine.dialect.DBType() == core.MYSQL { _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) } else { - engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", - table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + engine.LogWarn(fmt.Sprintf("Table %s Column %s db type is %s, struct type is %s\n", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name)) } } else { - engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", - table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + engine.LogWarn(fmt.Sprintf("Table %s Column %s db type is %s, struct type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name)) } } if col.Default != oriCol.Default { - engine.LogWarn("Table %s Column %s Old default is %s, new default is %s", - table.Name, col.Name, oriCol.Default, col.Default) + engine.LogWarn(fmt.Sprintf("Table %s Column %s db default is %s, struct default is %s", + table.Name, col.Name, oriCol.Default, col.Default)) } if col.Nullable != oriCol.Nullable { - engine.LogWarn("Table %s Column %s Old nullable is %v, new nullable is %v", - table.Name, col.Name, oriCol.Nullable, col.Nullable) + engine.LogWarn(fmt.Sprintf("Table %s Column %s db nullable is %v, struct nullable is %v", + table.Name, col.Name, oriCol.Nullable, col.Nullable)) } } else { session := engine.NewSession() @@ -1430,6 +1437,8 @@ func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{} case core.TimeStampz: if engine.dialect.DBType() == core.MSSQL { v = engine.TZTime(t).Format("2006-01-02T15:04:05.9999999Z07:00") + } else if engine.DriverName() == "mssql" { + v = engine.TZTime(t) } else { v = engine.TZTime(t).Format(time.RFC3339Nano) } diff --git a/postgres_dialect.go b/postgres_dialect.go index a088664c..61d4f1e2 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -46,6 +46,8 @@ func (db *postgres) SqlType(c *core.Column) string { res = core.Real case core.TinyText, core.MediumText, core.LongText: res = core.Text + case core.Uuid: + res = core.Uuid case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: return core.Bytea case core.Double: @@ -143,8 +145,17 @@ func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, err func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} - s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + - ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix , + CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, + CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey +FROM pg_attribute f + JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid + LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum + LEFT JOIN pg_namespace n ON n.oid = c.relnamespace + LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) + LEFT JOIN pg_class AS g ON p.confrelid = g.oid + LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name +WHERE c.relkind = 'r'::char AND c.relname = $1 AND f.attnum > 0 ORDER BY f.attnum;` rows, err := db.DB().Query(s, args...) if err != nil { @@ -161,11 +172,12 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col var colName, isNullable, dataType string var maxLenStr, colDefault, numPrecision, numRadix *string - err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &numPrecision, &numRadix) + var isPK, isUnique bool + err = rows.Scan(&colName, &colDefault, &isNullable, &dataType, &maxLenStr, &numPrecision, &numRadix, &isPK, &isUnique) if err != nil { return nil, nil, err } - + //fmt.Println(args,colName, isNullable, dataType,maxLenStr, colDefault, numPrecision, numRadix,isPK ,isUnique) var maxLen int if maxLenStr != nil { maxLen, err = strconv.Atoi(*maxLenStr) @@ -176,8 +188,8 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col col.Name = strings.Trim(colName, `" `) - if colDefault != nil { - if strings.HasPrefix(*colDefault, "nextval") { + if colDefault != nil || isPK { + if isPK { col.IsPrimaryKey = true } else { col.Default = *colDefault diff --git a/pq_driver.go b/pq_driver.go index a5bb6718..c8dd5aa0 100644 --- a/pq_driver.go +++ b/pq_driver.go @@ -3,6 +3,8 @@ package xorm import ( "errors" "fmt" + "net/url" + "sort" "strings" "github.com/go-xorm/core" @@ -29,6 +31,53 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +func parseURL(connstr string) (string, error) { + u, err := url.Parse(connstr) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"="+escaper.Replace(v)) + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + + v, _ = u.User.Password() + accrue("password", v) + } + + i := strings.Index(u.Host, ":") + if i < 0 { + accrue("host", u.Host) + } else { + accrue("host", u.Host[:i]) + accrue("port", u.Host[i+1:]) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} + func parseOpts(name string, o values) { if len(name) == 0 { return @@ -49,6 +98,13 @@ func parseOpts(name string, o values) { func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { db := &core.Uri{DbType: core.POSTGRES} o := make(values) + var err error + if strings.HasPrefix(dataSourceName, "postgres://") { + dataSourceName, err = parseURL(dataSourceName) + if err != nil { + return nil, err + } + } parseOpts(dataSourceName, o) db.DbName = o.Get("dbname") diff --git a/session.go b/session.go index b1636f4e..f0a345b2 100644 --- a/session.go +++ b/session.go @@ -39,7 +39,8 @@ type Session struct { beforeClosures []func(interface{}) afterClosures []func(interface{}) - stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) + cascadeDeep int } // Method Init reset the session as the init status. @@ -145,6 +146,12 @@ func (session *Session) Incr(column string, arg ...interface{}) *Session { return session } +// Method Decr provides a query string like "count = count - 1" +func (session *Session) Decr(column string, arg ...interface{}) *Session { + session.Statement.Decr(column, arg...) + return session +} + // Method Cols provides some columns to special func (session *Session) Cols(columns ...string) *Session { session.Statement.Cols(columns...) @@ -389,7 +396,8 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b for key, data := range objMap { if col = table.GetColumn(key); col == nil { - session.Engine.LogWarn(fmt.Sprintf("table %v's has not column %v. %v", table.Name, key, table.Columns())) + session.Engine.LogWarn(fmt.Sprintf("struct %v's has not field %v. %v", + table.Type.Name(), key, table.ColumnsSeq())) continue } @@ -895,25 +903,23 @@ func (session *Session) Iterate(bean interface{}, fun IterFunc) error { rows, err := session.Rows(bean) if err != nil { return err - } else { - defer rows.Close() - //b := reflect.New(iterator.beanType).Interface() - i := 0 - for rows.Next() { - b := reflect.New(rows.beanType).Interface() - err = rows.Scan(b) - if err != nil { - return err - } - err = fun(i, b) - if err != nil { - return err - } - i++ - } - return err } - return nil + defer rows.Close() + //b := reflect.New(iterator.beanType).Interface() + i := 0 + for rows.Next() { + b := reflect.New(rows.beanType).Interface() + err = rows.Scan(b) + if err != nil { + return err + } + err = fun(i, b) + if err != nil { + return err + } + i++ + } + return err } func (session *Session) doPrepare(sqlStr string) (stmt *core.Stmt, err error) { @@ -2451,6 +2457,38 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } fieldValue.Set(reflect.ValueOf(&x)) default: + if fieldType.Elem().Kind() == reflect.Struct { + if session.Statement.UseCascade { + structInter := reflect.New(fieldType.Elem()) + fmt.Println(structInter, fieldType.Elem()) + table := session.Engine.autoMapType(structInter.Elem()) + if table != nil { + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + if x != 0 { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(x).Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v = structInter.Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } + } + } + } else { + return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) + } + } return fmt.Errorf("unsupported type in Scan: %s", reflect.TypeOf(v).String()) } default: @@ -2565,6 +2603,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } else { return nil, ErrUnSupportedType } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + return int64(fieldValue.Uint()), nil default: return fieldValue.Interface(), nil } @@ -3009,6 +3049,13 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" + ?") args = append(args, v.arg) } + //for update action to like "column = column - ?" + decColumns := session.Statement.getDec() + for _, v := range decColumns { + colNames = append(colNames, session.Engine.Quote(v.colName)+" = "+session.Engine.Quote(v.colName)+" - ?") + args = append(args, v.arg) + } + var condiColNames []string var condiArgs []interface{} diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index 0626cf4e..ddf6a5f2 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -129,7 +129,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu } nStart := strings.Index(name, "(") - nEnd := strings.Index(name, ")") + nEnd := strings.LastIndex(name, ")") colCreates := strings.Split(name[nStart+1:nEnd], ",") cols := make(map[string]*core.Column) colSeq := make([]string, 0) diff --git a/statement.go b/statement.go index 33aa20a7..9286250d 100644 --- a/statement.go +++ b/statement.go @@ -20,6 +20,11 @@ type incrParam struct { arg interface{} } +type decrParam struct { + colName string + arg interface{} +} + // statement save all the sql info for executing SQL type Statement struct { RefTable *core.Table @@ -54,6 +59,7 @@ type Statement struct { mustColumnMap map[string]bool inColumns map[string]*inParam incrColumns map[string]incrParam + decrColumns map[string]decrParam } // init @@ -85,6 +91,7 @@ func (statement *Statement) Init() { statement.checkVersion = true statement.inColumns = make(map[string]*inParam) statement.incrColumns = make(map[string]incrParam) + statement.decrColumns = make(map[string]decrParam) } // add the raw sql statement @@ -375,7 +382,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if !requiredField && fieldValue.Uint() == 0 { continue } - val = fieldValue.Interface() + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) @@ -546,7 +554,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if !requiredField && fieldValue.Uint() == 0 { continue } - val = fieldValue.Interface() + t := int64(fieldValue.Uint()) + val = reflect.ValueOf(&t).Interface() case reflect.Struct: if fieldType == reflect.TypeOf(time.Now()) { t := fieldValue.Interface().(time.Time) @@ -674,11 +683,27 @@ func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { return statement } +// Generate "Update ... Set column = column - arg" statment +func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { + k := strings.ToLower(column) + if len(arg) > 0 { + statement.decrColumns[k] = decrParam{column, arg[0]} + } else { + statement.decrColumns[k] = decrParam{column, 1} + } + return statement +} + // Generate "Update ... Set column = column + arg" statment func (statement *Statement) getInc() map[string]incrParam { return statement.incrColumns } +// Generate "Update ... Set column = column - arg" statment +func (statement *Statement) getDec() map[string]decrParam { + return statement.decrColumns +} + // Generate "Where column IN (?) " statment func (statement *Statement) In(column string, args ...interface{}) *Statement { k := strings.ToLower(column) @@ -833,7 +858,7 @@ func (statement *Statement) OrderBy(order string) *Statement { //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(join_operator, tablename, condition string) *Statement { if statement.JoinStr != "" { - statement.JoinStr = statement.JoinStr + fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) + statement.JoinStr = statement.JoinStr + fmt.Sprintf(" %v JOIN %v ON %v", join_operator, tablename, condition) } else { statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, tablename, condition) } diff --git a/xorm.go b/xorm.go index 86a4ee55..bf4aa8f6 100644 --- a/xorm.go +++ b/xorm.go @@ -1,7 +1,6 @@ package xorm import ( - "database/sql" "errors" "fmt" "os" @@ -17,41 +16,29 @@ const ( Version string = "0.4" ) -// !nashtsai! implicit register drivers and dialects is no good, as init() can be called before sql driver got registered -// func init() { -// regDrvsNDialects() -// } - func regDrvsNDialects() bool { - if core.RegisteredDriverSize() == 0 { - providedDrvsNDialects := map[string]struct { - dbType core.DbType - getDriver func() core.Driver - getDialect func() core.Dialect - }{ - "odbc": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access - "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, - "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, - "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, - "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, - "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, - "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, - } - - for driverName, v := range providedDrvsNDialects { - _, err := sql.Open(driverName, "") - if err == nil { - // fmt.Printf("driver succeed: %v\n", driverName) - core.RegisterDriver(driverName, v.getDriver()) - core.RegisterDialect(v.dbType, v.getDialect()) - } else { - // fmt.Printf("driver failed: %v | err: %v\n", driverName, err) - } - } - return true - } else { - return false + providedDrvsNDialects := map[string]struct { + dbType core.DbType + getDriver func() core.Driver + getDialect func() core.Dialect + }{ + "mssql": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, + "odbc": {"mssql", func() core.Driver { return &odbcDriver{} }, func() core.Dialect { return &mssql{} }}, // !nashtsai! TODO change this when supporting MS Access + "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, + "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, + "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, + "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, + "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, + "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }}, } + + for driverName, v := range providedDrvsNDialects { + if driver := core.QueryDriver(driverName); driver == nil { + core.RegisterDriver(driverName, v.getDriver()) + core.RegisterDialect(v.dbType, v.getDialect()) + } + } + return true } func close(engine *Engine) {