diff --git a/session_insert.go b/session_insert.go index e673e874..03d8962c 100644 --- a/session_insert.go +++ b/session_insert.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "reflect" + "sort" "strconv" "strings" @@ -24,32 +25,67 @@ func (session *Session) Insert(beans ...interface{}) (int64, error) { } for _, bean := range beans { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - if sliceValue.Kind() == reflect.Slice { - size := sliceValue.Len() - if size > 0 { - if session.engine.SupportInsertMany() { - cnt, err := session.innerInsertMulti(bean) - if err != nil { - return affected, err - } - affected += cnt - } else { - for i := 0; i < size; i++ { - cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) - if err != nil { - return affected, err - } - affected += cnt - } - } - } - } else { - cnt, err := session.innerInsert(bean) + switch bean.(type) { + case map[string]interface{}: + cnt, err := session.insertMapInterface(bean.(map[string]interface{})) if err != nil { return affected, err } affected += cnt + case []map[string]interface{}: + s := bean.([]map[string]interface{}) + session.autoResetStatement = false + for i := 0; i < len(s); i++ { + cnt, err := session.insertMapInterface(s[i]) + if err != nil { + return affected, err + } + affected += cnt + } + case map[string]string: + cnt, err := session.insertMapString(bean.(map[string]string)) + if err != nil { + return affected, err + } + affected += cnt + case []map[string]string: + s := bean.([]map[string]string) + session.autoResetStatement = false + for i := 0; i < len(s); i++ { + cnt, err := session.insertMapString(s[i]) + if err != nil { + return affected, err + } + affected += cnt + } + default: + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + if sliceValue.Kind() == reflect.Slice { + size := sliceValue.Len() + if size > 0 { + if session.engine.SupportInsertMany() { + cnt, err := session.innerInsertMulti(bean) + if err != nil { + return affected, err + } + affected += cnt + } else { + for i := 0; i < size; i++ { + cnt, err := session.innerInsert(sliceValue.Index(i).Interface()) + if err != nil { + return affected, err + } + affected += cnt + } + } + } + } else { + cnt, err := session.innerInsert(bean) + if err != nil { + return affected, err + } + affected += cnt + } } } @@ -622,3 +658,67 @@ func (session *Session) genInsertColumns(bean interface{}) ([]string, []interfac } return colNames, args, nil } + +func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) { + var columns = make([]string, 0, len(m)) + for k := range m { + columns = append(columns, k) + } + sort.Strings(columns) + + qm := strings.Repeat("?,", len(columns)) + qm = "(" + qm[:len(qm)-1] + ")" + + tableName := session.statement.AltTableName + var sql = "INSERT INTO `" + tableName + "` (`" + strings.Join(columns, "`,`") + "`) VALUES " + qm + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, nil +} + +func (session *Session) insertMapString(m map[string]string) (int64, error) { + var columns = make([]string, 0, len(m)) + for k := range m { + columns = append(columns, k) + } + sort.Strings(columns) + + qm := strings.Repeat("?,", len(columns)) + qm = "(" + qm[:len(qm)-1] + ")" + + tableName := session.statement.AltTableName + var sql = "INSERT INTO `" + tableName + "` (`" + strings.Join(columns, "`,`") + "`) VALUES " + qm + var args = make([]interface{}, 0, len(m)) + for _, colName := range columns { + args = append(args, m[colName]) + } + + if err := session.cacheInsert(tableName); err != nil { + return 0, err + } + + res, err := session.exec(sql, args...) + if err != nil { + return 0, err + } + affected, err := res.RowsAffected() + if err != nil { + return 0, err + } + return affected, nil +} diff --git a/session_insert_test.go b/session_insert_test.go index 50943032..2027fbb7 100644 --- a/session_insert_test.go +++ b/session_insert_test.go @@ -780,3 +780,82 @@ func TestAnonymousStruct(t *testing.T) { }) assert.NoError(t, err) } + +func TestInsertMap(t *testing.T) { + type InsertMap struct { + Id int64 + Width uint32 + Height uint32 + Name string + } + + assert.NoError(t, prepareEngine()) + assertSync(t, new(InsertMap)) + + cnt, err := testEngine.Table(new(InsertMap)).Insert(map[string]interface{}{ + "width": 20, + "height": 10, + "name": "lunny", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var im InsertMap + has, err := testEngine.Get(&im) + assert.NoError(t, err) + assert.True(t, has) + assert.EqualValues(t, 20, im.Width) + assert.EqualValues(t, 10, im.Height) + assert.EqualValues(t, "lunny", im.Name) + + cnt, err = testEngine.Table("insert_map").Insert(map[string]interface{}{ + "width": 30, + "height": 10, + "name": "lunny", + }) + assert.NoError(t, err) + assert.EqualValues(t, 1, cnt) + + var ims []InsertMap + err = testEngine.Find(&ims) + assert.NoError(t, err) + assert.EqualValues(t, 2, len(ims)) + assert.EqualValues(t, 20, ims[0].Width) + assert.EqualValues(t, 10, ims[0].Height) + assert.EqualValues(t, "lunny", ims[0].Name) + assert.EqualValues(t, 30, ims[1].Width) + assert.EqualValues(t, 10, ims[1].Height) + assert.EqualValues(t, "lunny", ims[1].Name) + + cnt, err = testEngine.Table("insert_map").Insert([]map[string]interface{}{ + { + "width": 40, + "height": 10, + "name": "lunny", + }, + { + "width": 50, + "height": 10, + "name": "lunny", + }, + }) + assert.NoError(t, err) + assert.EqualValues(t, 2, cnt) + + ims = make([]InsertMap, 0, 4) + err = testEngine.Find(&ims) + assert.NoError(t, err) + assert.EqualValues(t, 4, len(ims)) + assert.EqualValues(t, 20, ims[0].Width) + assert.EqualValues(t, 10, ims[0].Height) + assert.EqualValues(t, "lunny", ims[1].Name) + assert.EqualValues(t, 30, ims[1].Width) + assert.EqualValues(t, 10, ims[1].Height) + assert.EqualValues(t, "lunny", ims[1].Name) + assert.EqualValues(t, 40, ims[2].Width) + assert.EqualValues(t, 10, ims[2].Height) + assert.EqualValues(t, "lunny", ims[2].Name) + assert.EqualValues(t, 50, ims[3].Width) + assert.EqualValues(t, 10, ims[3].Height) + assert.EqualValues(t, "lunny", ims[3].Name) +}