diff --git a/engine.go b/engine.go index e1a9946e..146b2ee0 100644 --- a/engine.go +++ b/engine.go @@ -157,6 +157,20 @@ func (engine *Engine) MapOne(bean interface{}) Table { return engine.MapType(t) } +func (engine *Engine) AutoMapType(t reflect.Type) *Table { + table, ok := engine.Tables[t] + if !ok { + table = engine.MapType(t) + engine.Tables[t] = table + } + return &table +} + +func (engine *Engine) AutoMap(bean interface{}) *Table { + t := Type(bean) + return engine.AutoMapType(t) +} + func (engine *Engine) MapType(t reflect.Type) Table { table := Table{Name: engine.Mapper.Obj2Table(t.Name()), Type: t} table.Columns = make(map[string]Column) @@ -183,6 +197,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { switch { case k == "pk": col.IsPrimaryKey = true + col.Nullable = false pkCol = &col case k == "null": col.Nullable = (tags[j-1] != "not") diff --git a/session.go b/session.go index 96fa79cc..70ae825d 100644 --- a/session.go +++ b/session.go @@ -211,7 +211,8 @@ func (session *Session) Get(bean interface{}) error { statement := session.Statement defer session.Statement.Init() statement.Limit(1) - table := session.Engine.Bean2Table(bean) + + table := session.Engine.AutoMap(bean) statement.Table = table colNames, args := session.BuildConditions(table, bean) @@ -241,7 +242,7 @@ func (session *Session) Get(bean interface{}) error { func (session *Session) Count(bean interface{}) (int64, error) { statement := session.Statement defer session.Statement.Init() - table := session.Engine.Bean2Table(bean) + table := session.Engine.AutoMap(bean) statement.Table = table colNames, args := session.BuildConditions(table, bean) @@ -271,12 +272,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } sliceElementType := sliceValue.Type().Elem() - - table := session.Engine.Tables[sliceElementType] - statement.Table = &table + table := session.Engine.AutoMapType(sliceElementType) + statement.Table = table if len(condiBean) > 0 { - colNames, args := session.BuildConditions(&table, condiBean[0]) + colNames, args := session.BuildConditions(table, condiBean[0]) statement.ColumnStr = strings.Join(colNames, " and ") statement.BeanArgs = args } @@ -428,8 +428,8 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { bean := sliceValue.Index(0).Interface() sliceElementType := Type(bean) - table := session.Engine.Tables[sliceElementType] - session.Statement.Table = &table + table := session.Engine.AutoMapType(sliceElementType) + session.Statement.Table = table size := sliceValue.Len() @@ -490,7 +490,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { } func (session *Session) InsertOne(bean interface{}) (int64, error) { - table := session.Engine.Bean2Table(bean) + table := session.Engine.AutoMap(bean) session.Statement.Table = table colNames := make([]string, 0) colPlaces := make([]string, 0) @@ -561,7 +561,7 @@ func (session *Session) BuildConditions(table *Table, bean interface{}) ([]strin } func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { - table := session.Engine.Bean2Table(bean) + table := session.Engine.AutoMap(bean) session.Statement.Table = table colNames, args := session.BuildConditions(table, bean) var condiColNames []string @@ -610,7 +610,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } func (session *Session) Delete(bean interface{}) (int64, error) { - table := session.Engine.Bean2Table(bean) + table := session.Engine.AutoMap(bean) session.Statement.Table = table colNames, args := session.BuildConditions(table, bean)