diff --git a/README.md b/README.md index 3c69a35b..25fbc7b2 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,8 @@ Or # Cases +* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) + * [Wego](http://github.com/go-tango/wego) * [Docker.cn](https://docker.cn/) diff --git a/README_CN.md b/README_CN.md index 07a26284..fb08040b 100644 --- a/README_CN.md +++ b/README_CN.md @@ -79,6 +79,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 ## 案例 +* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) + * [Wego](http://github.com/go-tango/wego) * [Docker.cn](https://docker.cn/) diff --git a/VERSION b/VERSION index 988f30d1..af3e71ec 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -xorm v0.4.3.0428 +xorm v0.4.3.0520 diff --git a/engine.go b/engine.go index 915912de..94ac29e3 100644 --- a/engine.go +++ b/engine.go @@ -517,6 +517,12 @@ func (engine *Engine) Distinct(columns ...string) *Session { return session.Distinct(columns...) } +func (engine *Engine) Select(str string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Select(str) +} + // only use the paramters as select or update columns func (engine *Engine) Cols(columns ...string) *Session { session := engine.NewSession() @@ -653,20 +659,18 @@ func (engine *Engine) Having(conditions string) *Session { func (engine *Engine) autoMapType(v reflect.Value) *core.Table { t := v.Type() - engine.mutex.RLock() + engine.mutex.Lock() table, ok := engine.Tables[t] - engine.mutex.RUnlock() if !ok { table = engine.mapType(v) - engine.mutex.Lock() engine.Tables[t] = table if v.CanAddr() { engine.GobRegister(v.Addr().Interface()) } else { engine.GobRegister(v.Interface()) } - engine.mutex.Unlock() } + engine.mutex.Unlock() return table } @@ -1123,7 +1127,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { session := engine.NewSession() session.Statement.RefTable = table defer session.Close() - isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col) + isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col.Name) if err != nil { return err } diff --git a/helpers.go b/helpers.go index 5208137a..979a67a1 100644 --- a/helpers.go +++ b/helpers.go @@ -133,8 +133,8 @@ func reflect2value(rawValue *reflect.Value) (str string, err error) { } //时间类型 case reflect.Struct: - if aa == core.TimeType { - str = rawValue.Interface().(time.Time).Format(time.RFC3339Nano) + if aa.ConvertibleTo(core.TimeType) { + str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano) } else { err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name()) } diff --git a/mssql_dialect.go b/mssql_dialect.go index a910406b..57985ad9 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -315,10 +315,10 @@ func (db *mssql) IndexCheckSql(tableName, idxName string) (string, []interface{} return sql, args }*/ -func (db *mssql) IsColumnExist(tableName string, col *core.Column) (bool, error) { +func (db *mssql) IsColumnExist(tableName, colName string) (bool, error) { query := `SELECT "COLUMN_NAME" FROM "INFORMATION_SCHEMA"."COLUMNS" WHERE "TABLE_NAME" = ? AND "COLUMN_NAME" = ?` - return db.HasRecords(query, tableName, col.Name) + return db.HasRecords(query, tableName, colName) } func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { diff --git a/oracle_dialect.go b/oracle_dialect.go index b4f89edf..dc64c00e 100644 --- a/oracle_dialect.go +++ b/oracle_dialect.go @@ -665,8 +665,8 @@ func (db *oracle) MustDropTable(tableName string) error { " AND column_name = ?", args }*/ -func (db *oracle) IsColumnExist(tableName string, col *core.Column) (bool, error) { - args := []interface{}{tableName, col.Name} +func (db *oracle) IsColumnExist(tableName, colName string) (bool, error) { + args := []interface{}{tableName, colName} query := "SELECT column_name FROM USER_TAB_COLUMNS WHERE table_name = :1" + " AND column_name = :2" rows, err := db.DB().Query(query, args...) diff --git a/postgres_dialect.go b/postgres_dialect.go index 04713cf5..67ceecd0 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -896,8 +896,8 @@ func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } -func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, error) { - args := []interface{}{tableName, col.Name} +func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) { + args := []interface{}{tableName, colName} query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" rows, err := db.DB().Query(query, args...) diff --git a/session.go b/session.go index a701c2bc..adf0aa8b 100644 --- a/session.go +++ b/session.go @@ -178,6 +178,12 @@ func (session *Session) SetExpr(column string, expression string) *Session { return session } +// Method Cols provides some columns to special +func (session *Session) Select(str string) *Session { + session.Statement.Select(str) + return session +} + // Method Cols provides some columns to special func (session *Session) Cols(columns ...string) *Session { session.Statement.Cols(columns...) @@ -627,12 +633,20 @@ func (statement *Statement) convertIdSql(sqlStr string) string { return "" } -func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { - // if has no reftable, then don't use cache currently +func (session *Session) canCache() bool { if session.Statement.RefTable == nil || session.Statement.JoinStr != "" || session.Statement.RawSQL != "" || - session.Tx != nil { + session.Tx != nil || + len(session.Statement.selectStr) > 0 { + return false + } + return true +} + +func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { + // if has no reftable, then don't use cache currently + if !session.canCache() { return false, ErrCacheFailed } @@ -730,10 +744,9 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { - if session.Statement.RefTable == nil || + if !session.canCache() || indexNoCase(sqlStr, "having") != -1 || - indexNoCase(sqlStr, "group by") != -1 || - session.Tx != nil { + indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed } @@ -883,7 +896,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for j := 0; j < len(temps); j++ { bean := temps[j] if bean == nil { - session.Engine.LogWarn("[cacheFind] cache no hit:", tableName, ides[j], temps) + session.Engine.LogWarn("[cacheFind] cache no hit:", tableName, ids[j], temps) // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead continue } @@ -1008,9 +1021,9 @@ func (session *Session) Get(bean interface{}) (bool, error) { var err error session.queryPreprocess(&sqlStr, args...) if session.IsAutoCommit { - stmt, err := session.doPrepare(sqlStr) - if err != nil { - return false, err + stmt, errPrepare := session.doPrepare(sqlStr) + if errPrepare != nil { + return false, errPrepare } // defer stmt.Close() // !nashtsai! don't close due to stmt is cached and bounded to this session rawRows, err = stmt.Query(args...) @@ -1171,9 +1184,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } if len(condiBean) > 0 { + var addedTableName = (len(session.Statement.JoinStr) > 0) colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap) + session.Statement.unscoped, session.Statement.mustColumnMap, + session.Statement.TableName(), addedTableName) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } else { @@ -1189,20 +1204,24 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr - if session.Statement.JoinStr == "" { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = session.Statement.genColumnStr() - } - } + if len(session.Statement.selectStr) > 0 { + columnStr = session.Statement.selectStr } else { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = "*" + if session.Statement.JoinStr == "" { + if columnStr == "" { + if session.Statement.GroupByStr != "" { + columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) + } else { + columnStr = session.Statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if session.Statement.GroupByStr != "" { + columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) + } else { + columnStr = "*" + } } } } @@ -2902,20 +2921,15 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val case reflect.String: return fieldValue.String(), nil case reflect.Struct: - if fieldType == core.TimeType { - switch fieldValue.Interface().(type) { - case time.Time: - t := fieldValue.Interface().(time.Time) - if session.Engine.dialect.DBType() == core.MSSQL { - if t.IsZero() { - return nil, nil - } + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) + if session.Engine.dialect.DBType() == core.MSSQL { + if t.IsZero() { + return nil, nil } - tf := session.Engine.FormatTime(col.SQLType.Name, t) - return tf, nil - default: - return fieldValue.Interface(), nil } + tf := session.Engine.FormatTime(col.SQLType.Name, t) + return tf, nil } if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { if len(fieldTable.PrimaryKeys) == 1 { @@ -2925,7 +2939,7 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val return 0, fmt.Errorf("no primary key for col %v", col.Name) } } else { - return 0, fmt.Errorf("Unsupported type %v\n", fieldValue.Type()) + return 0, fmt.Errorf("Unsupported type %v", fieldValue.Type()) } case reflect.Complex64, reflect.Complex128: bytes, err := json.Marshal(fieldValue.Interface()) @@ -3461,7 +3475,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 if len(condiBean) > 0 { condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap) + session.Statement.unscoped, session.Statement.mustColumnMap, session.Statement.TableName(), false) } var condition = "" @@ -3681,7 +3695,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { session.Statement.RefTable = table colNames, args := buildConditions(session.Engine, table, bean, true, true, false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap) + session.Statement.unscoped, session.Statement.mustColumnMap, + session.Statement.TableName(), false) var condition = "" var andStr = session.Engine.dialect.AndStr() diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index 8a770f78..c525e5be 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -249,9 +249,9 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { return sql, args }*/ -func (db *sqlite3) IsColumnExist(tableName string, col *core.Column) (bool, error) { +func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { args := []interface{}{tableName} - query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + col.Name + "`%') or (sql like '%[" + col.Name + "]%'))" + query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" rows, err := db.DB().Query(query, args...) if db.Logger != nil { db.Logger.Info("[sql]", query, args) diff --git a/statement.go b/statement.go index 6e1e87c2..3278662d 100644 --- a/statement.go +++ b/statement.go @@ -49,6 +49,7 @@ type Statement struct { GroupByStr string HavingStr string ColumnStr string + selectStr string columnMap map[string]bool useAllCols bool OmitStr string @@ -102,6 +103,7 @@ func (statement *Statement) Init() { statement.UseAutoTime = true statement.IsDistinct = false statement.TableAlias = "" + statement.selectStr = "" statement.allUseBool = false statement.useAllCols = false statement.mustColumnMap = make(map[string]bool) @@ -185,122 +187,6 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { return statement } -/*func (statement *Statement) genFields(bean interface{}) map[string]interface{} { - results := make(map[string]interface{}) - table := statement.Engine.TableInfo(bean) - for _, col := range table.Columns { - fieldValue := col.ValueOf(bean) - fieldType := reflect.TypeOf(fieldValue.Interface()) - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool { - val = fieldValue.Interface() - } else if _, ok := boolColumnMap[col.Name]; ok { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) - if t.IsZero() || !fieldValue.IsValid() { - continue - } - var str string - if col.SQLType.Name == Time { - s := t.UTC().Format("2006-01-02 15:04:05") - val = s[11:19] - } else if col.SQLType.Name == Date { - str = t.Format("2006-01-02") - val = str - } else { - val = t - } - } else { - engine.autoMapType(fieldValue.Type()) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) - if pkField.Int() != 0 { - val = pkField.Interface() - } else { - continue - } - } else { - val = fieldValue.Interface() - } - } - case reflect.Array, reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.LogError(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.LogError(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - results[col.Name] = val - } - return results -}*/ - // Auto generating conditions according a struct func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, @@ -429,8 +315,8 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, t := int64(fieldValue.Uint()) val = reflect.ValueOf(&t).Interface() case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } @@ -511,8 +397,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, func buildConditions(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool) ([]string, []interface{}) { - + mustColumnMap map[string]bool, tableName string, addedTableName bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns() { @@ -529,6 +414,14 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } + + var colName string + if addedTableName { + colName = engine.Quote(tableName)+"."+engine.Quote(col.Name) + } else { + colName = engine.Quote(col.Name) + } + fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -536,7 +429,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, } if col.IsDeleted && !unscoped { // tag "deleted" is enabled - colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", engine.Quote(col.Name), engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", + colName, colName)) } fieldValue := *fieldValuePtr @@ -558,7 +452,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if fieldValue.IsNil() { if includeNil { args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr())) + colNames = append(colNames, fmt.Sprintf("%v %s ?", colName, engine.dialect.EqStr())) } continue } else if !fieldValue.IsValid() { @@ -630,7 +524,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, } } else { //TODO: how to handler? - panic("not supported") + panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) } } else { val = fieldValue.Interface() @@ -681,7 +575,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { condi = "id() == ?" } else { - condi = fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr()) + condi = fmt.Sprintf("%v %s ?", colName, engine.dialect.EqStr()) } colNames = append(colNames, condi) } @@ -877,6 +771,12 @@ func (statement *Statement) Distinct(columns ...string) *Statement { return statement } +// replace select +func (s *Statement) Select(str string) *Statement { + s.selectStr = str + return s +} + // Generate "col1, col2" statement func (statement *Statement) Cols(columns ...string) *Statement { newColumns := col2NewCols(columns...) @@ -1153,9 +1053,11 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) table = statement.RefTable } + var addedTableName = (len(statement.JoinStr) > 0) + colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, true, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap) + statement.unscoped, statement.mustColumnMap, statement.TableName(), addedTableName) if !statement.IsWhereOnly { statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") @@ -1163,20 +1065,24 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) } var columnStr string = statement.ColumnStr - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if statement.GroupByStr != "" { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = statement.genColumnStr() - } - } + if len(statement.selectStr) > 0 { + columnStr = statement.selectStr } else { - if len(columnStr) == 0 { - if statement.GroupByStr != "" { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = "*" + if len(statement.JoinStr) == 0 { + if len(columnStr) == 0 { + if statement.GroupByStr != "" { + columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if len(columnStr) == 0 { + if statement.GroupByStr != "" { + columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + } else { + columnStr = "*" + } } } } @@ -1210,9 +1116,11 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} table := statement.Engine.TableInfo(bean) statement.RefTable = table + var addedTableName = (len(statement.JoinStr) > 0) + colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, true, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap) + statement.unscoped, statement.mustColumnMap, statement.TableName(), addedTableName) if !statement.IsWhereOnly { statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ") diff --git a/xorm.go b/xorm.go index 8b6cc568..036d420f 100644 --- a/xorm.go +++ b/xorm.go @@ -17,7 +17,7 @@ import ( ) const ( - Version string = "0.4.3.0428" + Version string = "0.4.3.0526" ) func regDrvsNDialects() bool {