diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fcbf9e31..4b806704 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,11 +21,11 @@ We appreciate any bug reports, but especially ones with self-contained further) test cases. It's especially helpful if you can submit a pull request with just the failing test case (you'll probably want to pattern it after the tests in -[base_test.go](https://github.com/go-xorm/xorm/blob/master/base_test.go) AND -[benchmark_base_test.go](https://github.com/go-xorm/xorm/blob/master/benchmark_base_test.go). +[base.go](https://github.com/go-xorm/tests/blob/master/base.go) AND +[benchmark.go](https://github.com/go-xorm/tests/blob/master/benchmark.go). If you implements a new database interface, you maybe need to add a _test.go file. -For example, [mysql_test.go](https://github.com/go-xorm/xorm/blob/master/mysql_test.go) +For example, [mysql_test.go](https://github.com/go-xorm/tests/blob/master/mysql/mysql_test.go) ### New functionality diff --git a/README.md b/README.md index 9c9ecfa7..730f890a 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ Xorm is a simple and powerful ORM for Go. * Query Cache speed up -* Database Reverse support, See [Xorm Tool README](https://github.com/go-xorm/xorm/blob/master/xorm/README.md) +* Database Reverse support, See [Xorm Tool README](https://github.com/go-xorm/cmd/blob/master/README.md) * Simple cascade loading support @@ -96,7 +96,7 @@ Or * [GoWalker](http://gowalker.org/github.com/go-xorm/xorm) -* [Quick Start](https://github.com/go-xorm/xorm/blob/master/docs/QuickStartEn.md) +* [Quick Start](https://github.com/go-xorm/xorm/blob/master/docs/QuickStart.md) # Cases @@ -123,13 +123,10 @@ Or Please visit [Xorm on Google Groups](https://groups.google.com/forum/#!forum/xorm) -# Contributors +# Contributing If you want to pull request, please see [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md) -* [Lunny](https://github.com/lunny) -* [Nashtsai](https://github.com/nashtsai) - # LICENSE BSD License diff --git a/README_CN.md b/README_CN.md index c8f3f180..f3974aa3 100644 --- a/README_CN.md +++ b/README_CN.md @@ -94,7 +94,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 ## 文档 -* [快速开始](https://github.com/go-xorm/xorm/blob/master/docs/QuickStart.md) +* [快速开始](https://github.com/go-xorm/xorm/blob/master/docs/QuickStartCN.md) * [GoWalker代码文档](http://gowalker.org/github.com/go-xorm/xorm) @@ -124,13 +124,10 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 请加入QQ群:280360085 进行讨论。 -# 贡献者 +## 贡献 如果您也想为Xorm贡献您的力量,请查看 [CONTRIBUTING](https://github.com/go-xorm/xorm/blob/master/CONTRIBUTING.md) -* [Lunny](https://github.com/lunny) -* [Nashtsai](https://github.com/nashtsai) - ## LICENSE BSD License diff --git a/docs/QuickStart.md b/docs/QuickStart.md index f240806d..2e1439c1 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -58,7 +58,7 @@ engine, err = xorm.NewEngine("sqlite3", "./test.db") defer engine.Close() ``` -Generally, you can only create one engine. Engine supports run on go rutines. +You can create many engines for different databases.Generally, you just need create only one engine. Engine supports run on go routines. xorm supports four drivers now: @@ -335,44 +335,43 @@ affected, err := engine.Insert(user, &questions) Notice: If you want to use transaction on inserting, you should use session.Begin() before calling Insert. -## 5.Query and count - -所有的查询条件不区分调用顺序,但必须在调用Get,Find,Count这三个函数之前调用。同时需要注意的一点是,在调用的参数中,所有的字符字段名均为映射后的数据库的字段名,而不是field的名字。 +## 5. Chainable APIs -### 5.1.查询条件方法 +### 5.1. Chainable APIs for Queries, Execusions and Aggregations +Queries and Aggregations is basically formed by using `Get`, `Find`, `Count` methods, with conjunction of following chainable APIs to form conditions, grouping and ordering: 查询和统计主要使用`Get`, `Find`, `Count`三个方法。在进行查询时可以使用多个方法来形成查询条件,条件函数如下: -* Id(int64) -传入一个PK字段的值,作为查询条件 +* Id([]interface{}) +Primary Key lookup * Where(string, …interface{}) -和Where语句中的条件基本相同,作为条件 +As SQL conditional WHERE clause * And(string, …interface{}) -和Where函数中的条件基本相同,作为条件 +Conditional AND * Or(string, …interface{}) -和Where函数中的条件基本相同,作为条件 +Conditional OR * Sql(string, …interface{}) 执行指定的Sql语句,并把结果映射到结构体 * Asc(…string) -指定字段名正序排序 +Ascending ordering on 1 or more fields * Desc(…string) -指定字段名逆序排序 +Descending ordering on 1 or more fields * OrderBy(string) -按照指定的顺序进行排序 +Custom ordering * In(string, …interface{}) -某字段在一些值中 +Conditional IN * Cols(…string) -只查询或更新某些指定的字段,默认是查询所有映射的字段或者根据Update的第一个参数来判断更新的字段。例如: +Explicity specify query or update columns. e.g.,: ```Go engine.Cols("age", "name").Find(&users) // SELECT age, name FROM user @@ -380,12 +379,10 @@ engine.Cols("age", "name").Update(&user) // UPDATE user SET age=? AND name=? ``` -其中的参数"age", "name"也可以写成"age, name",两种写法均可 - * Omit(...string) -和cols相反,此函数指定排除某些指定的字段。注意:此方法和Cols方法不可同时使用 +Inverse function to Cols, to exclude specify query or update columns. Warning: Don't use with Cols() ```Go -engine.Cols("age").Update(&user) +engine.Omit("age").Update(&user) // UPDATE user SET name = ? AND department = ? ``` diff --git a/docs/QuickStartCn.md b/docs/QuickStartCn.md index 2b76ca8e..607bdda6 100644 --- a/docs/QuickStartCn.md +++ b/docs/QuickStartCn.md @@ -23,7 +23,7 @@ xorm 快速入门 * [5.6.Count方法](#66) * [5.7.Rows方法](#67) * [6.更新数据](#70) -* [6.1.乐观锁](#71) + * [6.1.乐观锁](#71) * [7.删除数据](#80) * [8.执行SQL查询](#90) * [9.执行SQL命令](#100) @@ -62,7 +62,7 @@ engine, err = xorm.NewEngine("sqlite3", "./test.db") defer engine.Close() ``` -一般如果只针对一个数据库进行操作,只需要创建一个Engine即可。Engine支持在多GoRutine下使用。 +你可以创建一个或多个engine, 不过一般如果操作一个数据库,只需要创建一个Engine即可。Engine支持在多GoRutine下使用。 xorm当前支持五种驱动四个数据库如下: @@ -419,7 +419,7 @@ engine.Cols("age", "name").Update(&user) * Omit(...string) 和cols相反,此函数指定排除某些指定的字段。注意:此方法和Cols方法不可同时使用 ```Go -engine.Cols("age").Update(&user) +engine.Omit("age").Update(&user) // UPDATE user SET name = ? AND department = ? ``` diff --git a/engine.go b/engine.go index 603e43f6..2aaee105 100644 --- a/engine.go +++ b/engine.go @@ -6,6 +6,7 @@ import ( "database/sql" "errors" "fmt" + "io" "os" "reflect" "strconv" @@ -34,8 +35,7 @@ type Engine struct { ShowErr bool ShowDebug bool ShowWarn bool - //Pool IConnectPool - //Filters []core.Filter + Logger ILogger // io.Writer TZLocation *time.Location } @@ -266,6 +266,80 @@ func (engine *Engine) DBMetas() ([]*core.Table, error) { return tables, nil } +func (engine *Engine) DumpAllToFile(fp string) error { + f, err := os.Create(fp) + if err != nil { + return err + } + defer f.Close() + return engine.DumpAll(f) +} + +func (engine *Engine) DumpAll(w io.Writer) error { + tables, err := engine.DBMetas() + if err != nil { + return err + } + + for _, table := range tables { + _, err = io.WriteString(w, engine.dialect.CreateTableSql(table, "", "", "")+"\n\n") + if err != nil { + return err + } + for _, index := range table.Indexes { + _, err = io.WriteString(w, engine.dialect.CreateIndexSql(table.Name, index)+"\n\n") + if err != nil { + return err + } + } + + rows, err := engine.DB().Query("SELECT * FROM " + engine.Quote(table.Name)) + if err != nil { + return err + } + + cols, err := rows.Columns() + if err != nil { + return err + } + if len(cols) == 0 { + continue + } + for rows.Next() { + dest := make([]interface{}, len(cols)) + err = rows.ScanSlice(&dest) + if err != nil { + return err + } + + _, err = io.WriteString(w, "INSERT INTO "+engine.Quote(table.Name)+" ("+engine.Quote(strings.Join(cols, engine.Quote(", ")))+") VALUES (") + if err != nil { + return err + } + + var temp string + for i, d := range dest { + col := table.GetColumn(cols[i]) + if d == nil { + temp += ", NULL" + } else if col.SQLType.IsText() || col.SQLType.IsTime() { + var v = fmt.Sprintf("%s", d) + temp += ", '" + strings.Replace(v, "'", "''", -1) + "'" + } else if col.SQLType.IsBlob() /*reflect.TypeOf(d).Kind() == reflect.Slice*/ { + temp += fmt.Sprintf(", %s", engine.dialect.FormatBytes(d.([]byte))) + } else { + temp += fmt.Sprintf(", %s", d) + } + } + _, err = io.WriteString(w, temp[2:]+");\n\n") + if err != nil { + return err + } + } + } + return nil +} + // use cascade or not func (engine *Engine) Cascade(trueOrFalse ...bool) *Session { session := engine.NewSession() @@ -456,15 +530,6 @@ func (engine *Engine) autoMap(bean interface{}) *core.Table { return engine.autoMapType(v) } -/*func (engine *Engine) mapType(t reflect.Type) *core.Table { - return mappingTable(t, engine.TableMapper, engine.ColumnMapper, engine.dialect, engine.TagIdentifier) -}*/ - -/* -func mappingTable(t reflect.Type, tableMapper core.IMapper, colMapper core.IMapper, dialect core.Dialect, tagId string) *core.Table { - table := core.NewEmptyTable() - table.Name = tableMapper.Obj2Table(t.Name()) -*/ func addIndex(indexName string, table *core.Table, col *core.Column, indexType int) { if index, ok := table.Indexes[indexName]; ok { index.AddColumn(col.Name) @@ -524,17 +589,19 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { if tags[0] == "-" { continue } - if (strings.ToUpper(tags[0]) == "EXTENDS") && - (fieldType.Kind() == reflect.Struct) { + if strings.ToUpper(tags[0]) == "EXTENDS" { + fieldValue = reflect.Indirect(fieldValue) + if fieldValue.Kind() == reflect.Struct { + //parentTable := mappingTable(fieldType, tableMapper, colMapper, dialect, tagId) + parentTable := engine.mapType(fieldValue) + for _, col := range parentTable.Columns() { + col.FieldName = fmt.Sprintf("%v.%v", fieldValue.Type().Name(), col.FieldName) + table.AddColumn(col) + } - //parentTable := mappingTable(fieldType, tableMapper, colMapper, dialect, tagId) - parentTable := engine.mapType(fieldValue) - for _, col := range parentTable.Columns() { - col.FieldName = fmt.Sprintf("%v.%v", fieldType.Name(), col.FieldName) - table.AddColumn(col) + continue } - - continue + //TODO: warning } indexNames := make(map[string]int) @@ -599,20 +666,30 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { continue } col.SQLType = core.SQLType{fs[0], 0, 0} - fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") - if len(fs2) == 2 { - col.Length, err = strconv.Atoi(fs2[0]) - if err != nil { - engine.LogError(err) + if fs[0] == core.Enum && fs[1][0] == '\'' { //enum + options := strings.Split(fs[1][0:len(fs[1])-1], ",") + col.EnumOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.EnumOptions[v] = k } - col.Length2, err = strconv.Atoi(fs2[1]) - if err != nil { - engine.LogError(err) - } - } else if len(fs2) == 1 { - col.Length, err = strconv.Atoi(fs2[0]) - if err != nil { - engine.LogError(err) + } else { + fs2 := strings.Split(fs[1][0:len(fs[1])-1], ",") + if len(fs2) == 2 { + col.Length, err = strconv.Atoi(fs2[0]) + if err != nil { + engine.LogError(err) + } + col.Length2, err = strconv.Atoi(fs2[1]) + if err != nil { + engine.LogError(err) + } + } else if len(fs2) == 1 { + col.Length, err = strconv.Atoi(fs2[0]) + if err != nil { + engine.LogError(err) + } } } } else { @@ -701,19 +778,25 @@ func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) { session := engine.NewSession() defer session.Close() rows, err := session.Count(bean) - return rows > 0, err + return rows == 0, err } // If a table is exist func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { v := rValue(bean) - if v.Type().Kind() != reflect.Struct { + var tableName string + if v.Type().Kind() == reflect.String { + tableName = bean.(string) + } else if v.Type().Kind() == reflect.Struct { + table := engine.autoMapType(v) + tableName = table.Name + } else { return false, errors.New("bean should be a struct or struct's point") } - table := engine.autoMapType(v) + session := engine.NewSession() defer session.Close() - has, err := session.isTableExist(table.Name) + has, err := session.isTableExist(tableName) return has, err } diff --git a/examples/sync.go b/examples/sync.go index ad28ad80..d108e455 100644 --- a/examples/sync.go +++ b/examples/sync.go @@ -88,6 +88,17 @@ func main() { _, err = Orm.Insert(user) if err != nil { fmt.Println(err) + return + } + + isexist, err := Orm.IsTableExist("sync_user2") + if err != nil { + fmt.Println(err) + return + } + if !isexist { + fmt.Println("sync_user2 is not exist") + return } } } diff --git a/examples/tables.go b/examples/tables.go new file mode 100644 index 00000000..97c842be --- /dev/null +++ b/examples/tables.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "os" + + "github.com/go-xorm/xorm" + _ "github.com/mattn/go-sqlite3" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("need db path") + return + } + + orm, err := xorm.NewEngine("sqlite3", os.Args[1]) + if err != nil { + fmt.Println(err) + return + } + defer orm.Close() + orm.ShowSQL = true + + tables, err := orm.DBMetas() + if err != nil { + fmt.Println(err) + return + } + + for _, table := range tables { + fmt.Println(table.Name) + } +} diff --git a/mssql_dialect.go b/mssql_dialect.go index 42a3ae97..fd39cbec 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -211,6 +211,7 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err @@ -221,7 +222,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? var indexType int var indexName, colName, isUnique string - err = rows.Scan(&indexName, &colName, &isUnique, nil) + err = rows.Scan(&indexName, &colName, &isUnique) if err != nil { return nil, err } diff --git a/mysql_dialect.go b/mysql_dialect.go index 1ad86819..e76830ab 100644 --- a/mysql_dialect.go +++ b/mysql_dialect.go @@ -53,6 +53,17 @@ func (db *mysql) SqlType(c *core.Column) string { case core.TimeStampz: res = core.Char c.Length = 64 + case core.Enum: //mysql enum + res = core.Enum + res += "(" + for v, k := range c.EnumOptions { + if k > 0 { + res += fmt.Sprintf(",'%v'", v) + } else { + res += fmt.Sprintf("'%v'", v) + } + } + res += ")" default: res = t } @@ -140,26 +151,39 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column if colDefault != nil { col.Default = *colDefault + if col.Default == "" { + col.DefaultIsEmpty = true + } } cts := strings.Split(colType, "(") + colName := cts[0] + colType = strings.ToUpper(colName) var len1, len2 int if len(cts) == 2 { idx := strings.Index(cts[1], ")") - lens := strings.Split(cts[1][0:idx], ",") - len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) - if err != nil { - return nil, nil, err - } - if len(lens) == 2 { - len2, err = strconv.Atoi(lens[1]) + if colType == core.Enum && cts[1][0] == '\'' { //enum + options := strings.Split(cts[1][0:idx], ",") + col.EnumOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.EnumOptions[v] = k + } + } else { + lens := strings.Split(cts[1][0:idx], ",") + len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) if err != nil { return nil, nil, err } + if len(lens) == 2 { + len2, err = strconv.Atoi(lens[1]) + if err != nil { + return nil, nil, err + } + } } } - colName := cts[0] - colType = strings.ToUpper(colName) col.Length = len1 col.Length2 = len2 if _, ok := core.SqlTypes[colType]; ok { @@ -182,6 +206,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column if col.SQLType.IsText() { if col.Default != "" { col.Default = "'" + col.Default + "'" + } else { + if col.DefaultIsEmpty { + col.Default = "''" + } } } cols[col.Name] = col diff --git a/processors.go b/processors.go index 770515e6..03ae8e0f 100644 --- a/processors.go +++ b/processors.go @@ -15,6 +15,10 @@ type BeforeDeleteProcessor interface { BeforeDelete() } +type BeforeSetProcessor interface { + BeforeSet(string, Cell) +} + // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations //// Executed before an object is validated //type BeforeValidateProcessor interface { diff --git a/session.go b/session.go index 81207274..fe05f750 100644 --- a/session.go +++ b/session.go @@ -718,7 +718,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in if err != nil { return err } - // 查询数目太大,采用缓存将不是一个很好的方式。 + // 查询数目太大,采用缓存将不是一个很好的方式〠if len(resultsSlice) > 500 { session.Engine.LogDebug("[xorm:cacheFind] ids length %v > 500, no cache", len(resultsSlice)) return ErrCacheFailed @@ -883,7 +883,6 @@ func (session *Session) Rows(bean interface{}) (*Rows, error) { // are conditions. beans could be []Struct, []*Struct, map[int64]Struct // map[int64]*Struct func (session *Session) Iterate(bean interface{}, fun IterFunc) error { - rows, err := session.Rows(bean) if err != nil { return err @@ -982,24 +981,6 @@ func (session *Session) Get(bean interface{}) (bool, error) { } else { return false, nil } - - // resultsSlice, err := session.query(sqlStr, args...) - // if err != nil { - // return false, err - // } - // if len(resultsSlice) < 1 { - // return false, nil - // } - - // err = session.scanMapIntoStruct(bean, resultsSlice[0]) - // if err != nil { - // return true, err - // } - // if len(resultsSlice) == 1 { - // return true, nil - // } else { - // return true, errors.New("More than one record") - // } } // Count counts the records. bean's non-empty fields @@ -1083,7 +1064,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) if len(condiBean) > 0 { colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, false) + session.Statement.mustColumnMap) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } @@ -1442,6 +1423,8 @@ func (session *Session) getField(dataStruct *reflect.Value, key string, table *c return fieldValue } +type Cell *interface{} + func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount int, bean interface{}) error { dataStruct := rValue(bean) if dataStruct.Kind() != reflect.Struct { @@ -1450,18 +1433,24 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i table := session.Engine.autoMapType(dataStruct) - scanResultContainers := make([]interface{}, len(fields)) + scanResults := make([]interface{}, len(fields)) for i := 0; i < len(fields); i++ { - var scanResultContainer interface{} - scanResultContainers[i] = &scanResultContainer + var cell interface{} + scanResults[i] = &cell } - if err := rows.Scan(scanResultContainers...); err != nil { + if err := rows.Scan(scanResults...); err != nil { return err } + b, hasBeforeSet := bean.(BeforeSetProcessor) + for ii, key := range fields { + if hasBeforeSet { + b.BeforeSet(fields[ii], Cell(scanResults[ii].(*interface{}))) + } + if fieldValue := session.getField(&dataStruct, key, table); fieldValue != nil { - rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) + rawValue := reflect.Indirect(reflect.ValueOf(scanResults[ii])) //if row is null then ignore if rawValue.Interface() == nil { @@ -1469,7 +1458,18 @@ func (session *Session) row2Bean(rows *core.Rows, fields []string, fieldsCount i continue } - if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + if data, err := value2Bytes(&rawValue); err == nil { + structConvert.FromDB(data) + } else { + session.Engine.LogError(err) + } + continue + } + } + + if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { if data, err := value2Bytes(&rawValue); err == nil { structConvert.FromDB(data) } else { @@ -2452,6 +2452,16 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } } } + + if fieldConvert, ok := fieldValue.Interface().(core.Conversion); ok { + data, err := fieldConvert.ToDB() + if err != nil { + return 0, err + } else { + return string(data), nil + } + } + fieldType := fieldValue.Type() k := fieldType.Kind() if k == reflect.Ptr { @@ -2471,24 +2481,19 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val switch k { case reflect.Bool: return fieldValue.Bool(), nil - /*if fieldValue.Bool() { - return 1, nil - } else { - return 0, nil - }*/ case reflect.String: return fieldValue.String(), nil case reflect.Struct: if fieldType == core.TimeType { - t := fieldValue.Interface().(time.Time) - if session.Engine.dialect.DBType() == core.MSSQL { - if t.IsZero() { - return nil, nil - } - } switch fieldValue.Interface().(type) { case time.Time: - tf := session.Engine.FormatTime(col.SQLType.Name, fieldValue.Interface().(time.Time)) + t := fieldValue.Interface().(time.Time) + if session.Engine.dialect.DBType() == core.MSSQL { + if t.IsZero() { + return nil, nil + } + } + tf := session.Engine.FormatTime(col.SQLType.Name, t) return tf, nil default: return fieldValue.Interface(), nil @@ -2948,7 +2953,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 session.Statement.RefTable = table if session.Statement.ColumnStr == "" { - colNames, args = buildConditions(session.Engine, table, bean, false, false, + colNames, args = buildUpdates(session.Engine, table, bean, false, false, false, false, session.Statement.allUseBool, session.Statement.useAllCols, session.Statement.mustColumnMap, true) } else { @@ -2991,7 +2996,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, false) + session.Statement.mustColumnMap) } var condition = "" @@ -3195,7 +3200,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, false) + session.Statement.mustColumnMap) var condition = "" var andStr = session.Engine.dialect.AndStr() diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index 1e4f06bc..6e1ad41c 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -1,6 +1,7 @@ package xorm import ( + "fmt" "strings" "github.com/go-xorm/core" @@ -44,6 +45,10 @@ func (db *sqlite3) SqlType(c *core.Column) string { } } +func (db *sqlite3) FormatBytes(bs []byte) string { + return fmt.Sprintf("X'%x'", bs) +} + func (db *sqlite3) SupportInsertMany() bool { return true } diff --git a/statement.go b/statement.go index 7d6510f5..417d5b9d 100644 --- a/statement.go +++ b/statement.go @@ -257,11 +257,205 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { }*/ // Auto generating conditions according a struct -func buildConditions(engine *Engine, table *core.Table, bean interface{}, +func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, mustColumnMap map[string]bool, update bool) ([]string, []interface{}) { + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns() { + if !includeVersion && col.IsVersion { + continue + } + if col.IsCreated { + continue + } + if !includeUpdated && col.IsUpdated { + continue + } + if !includeAutoIncr && col.IsAutoIncrement { + continue + } + // + //fmt.Println(engine.dialect.DBType(), Text) + if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { + continue + } + fieldValuePtr, err := col.ValueOf(bean) + if err != nil { + engine.LogError(err) + continue + } + + fieldValue := *fieldValuePtr + fieldType := reflect.TypeOf(fieldValue.Interface()) + + requiredField := useAllCols + if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { + if b { + requiredField = true + } else { + continue + } + } + + var val interface{} + + if fieldValue.CanAddr() { + if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + engine.LogError(err) + } else { + val = data + } + continue + } + } + + if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { + data, err := structConvert.ToDB() + if err != nil { + engine.LogError(err) + } else { + val = data + } + continue + } + + if fieldType.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + if includeNil { + args = append(args, nil) + colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) + } + continue + } else if !fieldValue.IsValid() { + continue + } else { + // dereference ptr type to instance type + fieldValue = fieldValue.Elem() + fieldType = reflect.TypeOf(fieldValue.Interface()) + requiredField = true + } + } + + switch fieldType.Kind() { + case reflect.Bool: + if allUseBool || requiredField { + val = fieldValue.Interface() + } else { + // if a bool in a struct, it will not be as a condition because it default is false, + // please use Where() instead + continue + } + case reflect.String: + if !requiredField && fieldValue.String() == "" { + continue + } + // for MyString, should convert to string or panic + if fieldType.String() != reflect.String.String() { + val = fieldValue.String() + } else { + val = fieldValue.Interface() + } + case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: + if !requiredField && fieldValue.Int() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Float32, reflect.Float64: + if !requiredField && fieldValue.Float() == 0.0 { + continue + } + val = fieldValue.Interface() + case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: + if !requiredField && fieldValue.Uint() == 0 { + continue + } + val = fieldValue.Interface() + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { + continue + } + val = engine.FormatTime(col.SQLType.Name, t) + //fmt.Println("-------", t, val, col.Name) + } else { + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + if pkField.Int() != 0 { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + } + } else { + val = fieldValue.Interface() + } + } + case reflect.Array, reflect.Slice, reflect.Map: + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogError(err) + continue + } + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && + fieldType.Elem().Kind() == reflect.Uint8 { + if fieldValue.Len() > 0 { + val = fieldValue.Bytes() + } else { + continue + } + } else { + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogError(err) + continue + } + val = bytes + } + } else { + continue + } + default: + val = fieldValue.Interface() + } + + args = append(args, val) + if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { + continue + } + colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + } + + return colNames, args +} + +// Auto generating conditions according a struct +func buildConditions(engine *Engine, table *core.Table, bean interface{}, + includeVersion bool, includeUpdated bool, includeNil bool, + includeAutoIncr bool, allUseBool bool, useAllCols bool, + mustColumnMap map[string]bool) ([]string, []interface{}) { + colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns() { @@ -286,8 +480,11 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, } fieldValue := *fieldValuePtr - fieldType := reflect.TypeOf(fieldValue.Interface()) + if fieldValue.Interface() == nil { + continue + } + fieldType := reflect.TypeOf(fieldValue.Interface()) requiredField := useAllCols if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { if b { @@ -416,18 +613,10 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, args = append(args, val) var condi string - if update { - if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { - continue - } else { - condi = fmt.Sprintf("%v = ?", engine.Quote(col.Name)) - } + if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { + condi = "id() == ?" } else { - if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { - condi = "id() == ?" - } else { - condi = fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr()) - } + condi = fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr()) } colNames = append(colNames, condi) } @@ -493,10 +682,22 @@ func (statement *Statement) getInc() map[string]incrParam { // Generate "Where column IN (?) " statment func (statement *Statement) In(column string, args ...interface{}) *Statement { k := strings.ToLower(column) - if _, ok := statement.inColumns[k]; ok { - statement.inColumns[k].args = append(statement.inColumns[k].args, args...) + var newargs []interface{} + if len(args) == 1 && + reflect.TypeOf(args[0]).Kind() == reflect.Slice { + newargs = make([]interface{}, 0) + v := reflect.ValueOf(args[0]) + for i := 0; i < v.Len(); i++ { + newargs = append(newargs, v.Index(i).Interface()) + } } else { - statement.inColumns[k] = &inParam{column, args} + newargs = args + } + + if _, ok := statement.inColumns[k]; ok { + statement.inColumns[k].args = append(statement.inColumns[k].args, newargs...) + } else { + statement.inColumns[k] = &inParam{column, newargs} } return statement } @@ -750,7 +951,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, true, statement.allUseBool, statement.useAllCols, - statement.mustColumnMap, false) + statement.mustColumnMap) statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") statement.BeanArgs = args @@ -765,7 +966,7 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) func (s *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) { quote := s.Engine.Quote - sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()), + sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(s.TableName()), col.String(s.Engine.dialect)) return sql, []interface{}{} } @@ -789,12 +990,13 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.RefTable = table colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - true, statement.allUseBool, statement.useAllCols, statement.mustColumnMap, false) + true, statement.allUseBool, statement.useAllCols, statement.mustColumnMap) statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ") statement.BeanArgs = args + // count(index fieldname) > count(0) > count(*) - var id string = "0" + var id string = "*" if statement.Engine.Dialect().DBType() == "ql" { id = "" } @@ -811,21 +1013,59 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { distinct = "DISTINCT " } - // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr, - statement.Engine.Quote(statement.TableName())) - if statement.JoinStr != "" { - a = fmt.Sprintf("%v %v", a, statement.JoinStr) + var top string + var mssqlCondi string + var orderBy string + if statement.OrderStr != "" { + orderBy = fmt.Sprintf(" ORDER BY %v", statement.OrderStr) } statement.processIdParam() + var whereStr string if statement.WhereStr != "" { - a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) + whereStr = fmt.Sprintf(" WHERE %v", statement.WhereStr) if statement.ConditionStr != "" { - a = fmt.Sprintf("%v %v %v", a, statement.Engine.Dialect().AndStr(), + whereStr = fmt.Sprintf("%v %s %v", whereStr, statement.Engine.Dialect().AndStr(), statement.ConditionStr) } } else if statement.ConditionStr != "" { - a = fmt.Sprintf("%v WHERE %v", a, statement.ConditionStr) + whereStr = fmt.Sprintf(" WHERE %v", statement.ConditionStr) + } + var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) + if statement.JoinStr != "" { + fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) + } + + if statement.Engine.dialect.DBType() == core.MSSQL { + if statement.LimitN > 0 { + top = fmt.Sprintf(" TOP %d ", statement.LimitN) + } + if statement.Start > 0 { + var column string = "(id)" + if len(statement.RefTable.PKColumns()) == 0 { + for _, index := range statement.RefTable.Indexes { + if len(index.Cols) == 1 { + column = index.Cols[0] + break + } + } + if len(column) == 0 { + column = statement.RefTable.ColumnsSeq()[0] + } + } + mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s))", + column, statement.Start, column, fromStr, whereStr, orderBy) + } + } + + // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern + a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, + fromStr, whereStr) + if mssqlCondi != "" { + if whereStr != "" { + a += " AND " + mssqlCondi + } else { + a += " WHERE " + mssqlCondi + } } if statement.GroupByStr != "" { @@ -843,11 +1083,6 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } else if statement.LimitN > 0 { a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } - } else { - //TODO: for mssql, should handler limit. - /*SELECT * FROM ( - SELECT *, ROW_NUMBER() OVER (ORDER BY id desc) as row FROM "userinfo" - ) a WHERE row > [start] and row <= [start+limit] order by id desc*/ } return