diff --git a/engine.go b/engine.go index 35817f3f..eff608df 100644 --- a/engine.go +++ b/engine.go @@ -1222,7 +1222,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } err = session.addColumn(col.Name) @@ -1233,7 +1233,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { } for name, index := range table.Indexes { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } if index.Type == core.UniqueType { @@ -1242,7 +1242,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -1257,7 +1257,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { return err } if !isExist { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } diff --git a/engine_table.go b/engine_table.go index bb74ac16..fbc1fddd 100644 --- a/engine_table.go +++ b/engine_table.go @@ -35,11 +35,10 @@ func (engine *Engine) tbNameForMap(v reflect.Value) string { t := v.Type() if tb, ok := v.Interface().(TableName); ok { return tb.TableName() - } else { - if v.CanAddr() { - if tb, ok = v.Addr().Interface().(TableName); ok { - return tb.TableName() - } + } + if v.CanAddr() { + if tb, ok := v.Addr().Interface().(TableName); ok { + return tb.TableName() } } return engine.TableMapper.Obj2Table(t.Name()) @@ -69,7 +68,7 @@ func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) { v := rValue(f) t := v.Type() if t.Kind() == reflect.Struct { - fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name())) + fmt.Fprintf(w, engine.tbNameForMap(v)) } else { fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", f))) } @@ -89,7 +88,7 @@ func (engine *Engine) tbNameNoSchema(w io.Writer, tablename interface{}) { v := rValue(tablename) t := v.Type() if t.Kind() == reflect.Struct { - fmt.Fprintf(w, engine.TableMapper.Obj2Table(v.Type().Name())) + fmt.Fprintf(w, engine.tbNameForMap(v)) } else { fmt.Fprintf(w, engine.Quote(fmt.Sprintf("%v", tablename))) } diff --git a/rows.go b/rows.go index 31e29ae2..54ec7f37 100644 --- a/rows.go +++ b/rows.go @@ -32,7 +32,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { var args []interface{} var err error - if err = rows.session.statement.setRefValue(rValue(bean)); err != nil { + if err = rows.session.statement.setRefBean(bean); err != nil { return nil, err } @@ -94,8 +94,7 @@ func (rows *Rows) Scan(bean interface{}) error { return fmt.Errorf("scan arg is incompatible type to [%v]", rows.beanType) } - dataStruct := rValue(bean) - if err := rows.session.statement.setRefValue(dataStruct); err != nil { + if err := rows.session.statement.setRefBean(bean); err != nil { return err } @@ -104,6 +103,7 @@ func (rows *Rows) Scan(bean interface{}) error { return err } + dataStruct := rValue(bean) _, err = rows.session.slice2Bean(scanResults, rows.fields, bean, &dataStruct, rows.session.statement.RefTable) if err != nil { return err diff --git a/session_delete.go b/session_delete.go index 688b122c..eb91614c 100644 --- a/session_delete.go +++ b/session_delete.go @@ -79,7 +79,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { defer session.Close() } - if err := session.statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } diff --git a/session_exist.go b/session_exist.go index 378a6483..74a660e8 100644 --- a/session_exist.go +++ b/session_exist.go @@ -57,7 +57,7 @@ func (session *Session) Exist(bean ...interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefBean(bean[0]); err != nil { return false, err } } diff --git a/session_get.go b/session_get.go index 68b37af7..58191de1 100644 --- a/session_get.go +++ b/session_get.go @@ -31,7 +31,7 @@ func (session *Session) get(bean interface{}) (bool, error) { } if beanValue.Elem().Kind() == reflect.Struct { - if err := session.statement.setRefValue(beanValue.Elem()); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return false, err } } diff --git a/session_insert.go b/session_insert.go index 129ee230..8609b80c 100644 --- a/session_insert.go +++ b/session_insert.go @@ -298,7 +298,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) innerInsert(bean interface{}) (int64, error) { - if err := session.statement.setRefValue(rValue(bean)); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } if len(session.statement.TableName()) <= 0 { diff --git a/session_schema.go b/session_schema.go index e079c6ca..fad811b8 100644 --- a/session_schema.go +++ b/session_schema.go @@ -6,9 +6,7 @@ package xorm import ( "database/sql" - "errors" "fmt" - "reflect" "strings" "github.com/go-xorm/core" @@ -34,8 +32,7 @@ func (session *Session) CreateTable(bean interface{}) error { } func (session *Session) createTable(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -54,8 +51,7 @@ func (session *Session) CreateIndexes(bean interface{}) error { } func (session *Session) createIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -78,8 +74,7 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createUniques(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -103,8 +98,7 @@ func (session *Session) DropIndexes(bean interface{}) error { } func (session *Session) dropIndexes(bean interface{}) error { - v := rValue(bean) - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return err } @@ -168,19 +162,10 @@ func (session *Session) isTableExist(tableName string) (bool, error) { // IsTableEmpty if table have any records func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { - v := rValue(bean) - t := v.Type() - - if t.Kind() == reflect.String { - if session.isAutoClose { - defer session.Close() - } - return session.isTableEmpty(bean.(string)) - } else if t.Kind() == reflect.Struct { - rows, err := session.Count(bean) - return rows == 0, err + if session.isAutoClose { + defer session.Close() } - return false, errors.New("bean should be a struct or struct's point") + return session.isTableEmpty(session.engine.tbNameNoSchemaString(bean)) } func (session *Session) isTableEmpty(tableName string) (bool, error) { diff --git a/session_update.go b/session_update.go index f5587456..11264a61 100644 --- a/session_update.go +++ b/session_update.go @@ -167,7 +167,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var isMap = t.Kind() == reflect.Map var isStruct = t.Kind() == reflect.Struct if isStruct { - if err := session.statement.setRefValue(v); err != nil { + if err := session.statement.setRefBean(bean); err != nil { return 0, err } diff --git a/statement.go b/statement.go index 2e6a2a6d..5a2802d9 100644 --- a/statement.go +++ b/statement.go @@ -225,6 +225,16 @@ func (statement *Statement) setRefValue(v reflect.Value) error { return nil } +func (statement *Statement) setRefBean(bean interface{}) error { + var err error + statement.RefTable, err = statement.Engine.autoMapType(rValue(bean)) + if err != nil { + return err + } + statement.tableName = statement.Engine.TableNameWithSchema(statement.Engine.tbNameNoSchemaString(bean)) + return nil +} + // Auto generating update columnes and values according a struct func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, @@ -918,7 +928,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, v := rValue(bean) isStruct := v.Kind() == reflect.Struct if isStruct { - statement.setRefValue(v) + statement.setRefBean(bean) } var columnStr = statement.ColumnStr @@ -970,7 +980,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa var condArgs []interface{} var err error if len(beans) > 0 { - statement.setRefValue(rValue(beans[0])) + statement.setRefBean(beans[0]) condSQL, condArgs, err = statement.genConds(beans[0]) } else { condSQL, condArgs, err = builder.ToSQL(statement.cond) @@ -996,7 +1006,7 @@ func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interfa } func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { - statement.setRefValue(rValue(bean)) + statement.setRefBean(bean) var sumStrs = make([]string, 0, len(columns)) for _, colName := range columns {