diff --git a/base_test.go b/base_test.go index d2ea0606..d4c15f15 100644 --- a/base_test.go +++ b/base_test.go @@ -1448,12 +1448,32 @@ func testIndexAndUnique(engine *Engine, t *testing.T) { } type IntId struct { - Id int + Id int `xorm:"pk autoincr"` Name string } type Int32Id struct { - Id int32 + Id int32 `xorm:"pk autoincr"` + Name string +} + +type UintId struct { + Id uint `xorm:"pk autoincr"` + Name string +} + +type Uint32Id struct { + Id uint32 `xorm:"pk autoincr"` + Name string +} + +type Uint64Id struct { + Id uint64 `xorm:"pk autoincr"` + Name string +} + +type StringPK struct { + Id string `xorm:"pk notnull"` Name string } @@ -1470,11 +1490,51 @@ func testIntId(engine *Engine, t *testing.T) { panic(err) } - _, err = engine.Insert(&IntId{Name: "test"}) + cnt, err := engine.Insert(&IntId{Name: "test"}) if err != nil { t.Error(err) panic(err) } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(IntId) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]IntId, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&IntId{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } } func testInt32Id(engine *Engine, t *testing.T) { @@ -1490,11 +1550,295 @@ func testInt32Id(engine *Engine, t *testing.T) { panic(err) } - _, err = engine.Insert(&Int32Id{Name: "test"}) + cnt, err := engine.Insert(&Int32Id{Name: "test"}) if err != nil { t.Error(err) panic(err) } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(Int32Id) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]Int32Id, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&Int32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testUintId(engine *Engine, t *testing.T) { + err := engine.DropTables(&UintId{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&UintId{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&UintId{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(UintId) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]UintId, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&UintId{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testUint32Id(engine *Engine, t *testing.T) { + err := engine.DropTables(&Uint32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&Uint32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&Uint32Id{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(Uint32Id) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]Uint32Id, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&Uint32Id{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testUint64Id(engine *Engine, t *testing.T) { + err := engine.DropTables(&Uint64Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&Uint64Id{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&Uint64Id{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(Uint64Id) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]Uint64Id, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&Uint64Id{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } +} + +func testStringPK(engine *Engine, t *testing.T) { + err := engine.DropTables(&StringPK{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&StringPK{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&StringPK{Name: "test"}) + if err != nil { + t.Error(err) + panic(err) + } + + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } + + bean := new(StringPK) + has, err := engine.Get(bean) + if err != nil { + t.Error(err) + panic(err) + } + if !has { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + beans := make([]StringPK, 0) + err = engine.Find(&beans) + if err != nil { + t.Error(err) + panic(err) + } + if len(beans) != 1 { + err = errors.New("get count should be one") + t.Error(err) + panic(err) + } + + cnt, err = engine.Id(bean.Id).Delete(&StringPK{}) + if err != nil { + t.Error(err) + panic(err) + } + if cnt != 1 { + err = errors.New("insert count should be one") + t.Error(err) + panic(err) + } } func testMetaInfo(engine *Engine, t *testing.T) { @@ -3407,9 +3751,15 @@ func testAll2(engine *Engine, t *testing.T) { fmt.Println("-------------- testIndexAndUnique --------------") testIndexAndUnique(engine, t) fmt.Println("-------------- testIntId --------------") - //testIntId(engine, t) + testIntId(engine, t) fmt.Println("-------------- testInt32Id --------------") - //testInt32Id(engine, t) + testInt32Id(engine, t) + fmt.Println("-------------- testUintId --------------") + testUintId(engine, t) + fmt.Println("-------------- testUint32Id --------------") + testUint32Id(engine, t) + fmt.Println("-------------- testUint64Id --------------") + testUint64Id(engine, t) fmt.Println("-------------- testMetaInfo --------------") testMetaInfo(engine, t) fmt.Println("-------------- testIterate --------------") @@ -3446,4 +3796,6 @@ func testAll3(engine *Engine, t *testing.T) { testNullValue(engine, t) fmt.Println("-------------- testCompositeKey --------------") testCompositeKey(engine, t) + fmt.Println("-------------- testStringPK --------------") + testStringPK(engine, t) } diff --git a/engine.go b/engine.go index 65d2f095..ee31e47c 100644 --- a/engine.go +++ b/engine.go @@ -476,7 +476,7 @@ func (engine *Engine) mapType(t reflect.Type) *Table { table.ColumnsSeq = append(table.ColumnsSeq, name) } - table.PrimaryKey = parentTable.PrimaryKey + table.PrimaryKeys = parentTable.PrimaryKeys continue } var indexType int @@ -602,12 +602,13 @@ func (engine *Engine) mapType(t reflect.Type) *Table { } } - if idFieldColName != "" && table.PrimaryKey == "" { + if idFieldColName != "" && len(table.PrimaryKeys) == 0 { col := table.Columns[idFieldColName] col.IsPrimaryKey = true col.IsAutoIncrement = true col.Nullable = false - table.PrimaryKey = col.Name + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + table.AutoIncrement = col.Name } return table diff --git a/filter.go b/filter.go index 5fff4c0d..d2b6c468 100644 --- a/filter.go +++ b/filter.go @@ -40,10 +40,10 @@ type IdFilter struct { } func (i *IdFilter) Do(sql string, session *Session) string { - if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { - sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) - return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKey), -1) + if session.Statement.RefTable != nil && len(session.Statement.RefTable.PrimaryKeys) == 1 { + sql = strings.Replace(sql, "`(id)`", session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1) + sql = strings.Replace(sql, session.Engine.Quote("(id)"), session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1) + return strings.Replace(sql, "(id)", session.Engine.Quote(session.Statement.RefTable.PrimaryKeys[0]), -1) } return sql } diff --git a/postgres.go b/postgres.go index c316f9b5..9e12b009 100644 --- a/postgres.go +++ b/postgres.go @@ -67,7 +67,11 @@ func (db *postgres) SqlType(c *Column) string { switch t := c.SQLType.Name; t { case TinyInt: res = SmallInt + case MediumInt, Int, Integer: + if c.IsAutoIncrement { + return Serial + } return Integer case Serial, BigSerial: c.IsAutoIncrement = true diff --git a/postgres_test.go b/postgres_test.go index e7cc363f..907b4e15 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -8,11 +8,11 @@ import ( ) func newPostgresEngine() (*Engine, error) { - return NewEngine("postgres", "dbname=xorm_test sslmode=disable") + return NewEngine("postgres", "dbname=xorm_test user=lunny password=1234 sslmode=disable") } func newPostgresDriverDB() (*sql.DB, error) { - return sql.Open("postgres", "dbname=xorm_test sslmode=disable") + return sql.Open("postgres", "dbname=xorm_test user=lunny password=1234 sslmode=disable") } func TestPostgres(t *testing.T) { diff --git a/session.go b/session.go index 4fbe99db..03f1c637 100644 --- a/session.go +++ b/session.go @@ -577,7 +577,7 @@ func (session *Session) DropTable(bean interface{}) error { func (statement *Statement) convertIdSql(sql string) string { if statement.RefTable != nil { - col := statement.RefTable.PKColumn() + col := statement.RefTable.PKColumns()[0] if col != nil { sqls := splitNNoCase(sql, "from", 2) if len(sqls) != 2 { @@ -592,7 +592,8 @@ func (statement *Statement) convertIdSql(sql string) string { } func (session *Session) cacheGet(bean interface{}, sql string, args ...interface{}) (has bool, err error) { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + // if has no reftable or number of pks is not equal to 1, then don't use cache currently + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return false, ErrCacheFailed } for _, filter := range session.Engine.Filters { @@ -617,7 +618,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface if len(resultsSlice) > 0 { data := resultsSlice[0] var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return false, ErrCacheFailed } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -672,7 +673,7 @@ func (session *Session) cacheGet(bean interface{}, sql string, args ...interface func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr interface{}, args ...interface{}) (err error) { if session.Statement.RefTable == nil || - session.Statement.RefTable.PrimaryKey == "" || + len(session.Statement.RefTable.PrimaryKeys) != 1 || indexNoCase(sql, "having") != -1 || indexNoCase(sql, "group by") != -1 { return ErrCacheFailed @@ -708,7 +709,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter for _, data := range resultsSlice { //fmt.Println(data) var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -729,7 +730,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter } sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) - pkFieldName := session.Statement.RefTable.PKColumn().FieldName + pkFieldName := session.Statement.RefTable.PKColumns()[0].FieldName ididxes := make(map[int64]int) var ides []interface{} = make([]interface{}, 0) @@ -743,7 +744,18 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter } else { session.Engine.LogDebug("[xorm:cacheFind] cached bean:", tableName, id, bean) - sid := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName).Int() + pkField := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(pkFieldName) + + var sid int64 + switch pkField.Type().Kind() { + case reflect.Int32, reflect.Int, reflect.Int64: + sid = pkField.Int() + case reflect.Uint, reflect.Uint32, reflect.Uint64: + sid = int64(pkField.Uint()) + default: + return ErrCacheFailed + } + if sid != id { session.Engine.LogError("[xorm:cacheFind] error cache", id, sid, bean) return ErrCacheFailed @@ -795,7 +807,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter } } else if sliceValue.Kind() == reflect.Map { var key int64 - if table.PrimaryKey != "" { + if table.PrimaryKeys[0] != "" { key = ids[j] } else { key = int64(j) @@ -923,6 +935,7 @@ func (session *Session) Get(bean interface{}) (bool, error) { if err != nil { return false, err } + if len(resultsSlice) < 1 { return false, nil } @@ -1072,10 +1085,12 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } else if sliceValue.Kind() == reflect.Map { var key int64 - if table.PrimaryKey != "" { - x, err := strconv.ParseInt(string(results[table.PrimaryKey]), 10, 64) + // if there is only one pk, we can put the id as map key. + // TODO: should know if the column is ints + if len(table.PrimaryKeys) == 1 { + x, err := strconv.ParseInt(string(results[table.PrimaryKeys[0]]), 10, 64) if err != nil { - return errors.New("pk " + table.PrimaryKey + " as int64: " + err.Error()) + return errors.New("pk " + table.PrimaryKeys[0] + " as int64: " + err.Error()) } key = x } else { @@ -1258,7 +1273,6 @@ func row2map(rows *sql.Rows, fields []string) (resultsMap map[string][]byte, err // sql driver converted type back to []bytes then to ORM's fields for ii, key := range fields { rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii])) - //if row is null then ignore if rawValue.Interface() == nil { //fmt.Println("ignore ...", key, rawValue) @@ -1618,10 +1632,14 @@ func (session *Session) byte2Time(col *Column, data []byte) (outTime time.Time, ssd := strings.Split(sdata, " ") sdata = ssd[1] } - //if len(sdata) > 8 { - // sdata = sdata[len(sdata)-8:] - //} - fmt.Println(sdata) + + sdata = strings.TrimSpace(sdata) + //fmt.Println(sdata) + if session.Engine.dialect.DBType() == MYSQL && len(sdata) > 8 { + sdata = sdata[len(sdata)-8:] + } + //fmt.Println(sdata) + st := fmt.Sprintf("2006-01-02 %v", sdata) x, err = time.Parse("2006-01-02 15:04:05", st) } else { @@ -2082,14 +2100,14 @@ func (session *Session) value2Interface(col *Column, fieldValue reflect.Value) ( return fieldValue.Interface(), nil } if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { - if fieldTable.PrimaryKey != "" { - pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumn().FieldName) + if len(fieldTable.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) return pkField.Interface(), nil } else { - return 0, errors.New("no primary key") + return 0, fmt.Errorf("no primary key for col %v", col.Name) } } else { - return 0, errors.New(fmt.Sprintf("Unsupported type %v", fieldValue.Type())) + return 0, fmt.Errorf("Unsupported type %v\n", fieldValue.Type()) } case reflect.Complex64, reflect.Complex128: bytes, err := json.Marshal(fieldValue.Interface()) @@ -2195,7 +2213,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. - if session.Engine.DriverName != POSTGRES || table.PrimaryKey == "" { + if session.Engine.DriverName != POSTGRES || table.AutoIncrement == "" { res, err := session.exec(sql, args...) if err != nil { return 0, err @@ -2214,7 +2232,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if table.PrimaryKey == "" || table.PKColumn().SQLType.IsText() { + if table.AutoIncrement == "" { return res.RowsAffected() } @@ -2224,23 +2242,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return res.RowsAffected() } - pkValue := table.PKColumn().ValueOf(bean) - if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { + aiValue := table.AutoIncrColumn().ValueOf(bean) + if !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { return res.RowsAffected() } var v interface{} = id - switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + switch aiValue.Type().Kind() { + case reflect.Int32: + v = int32(id) + case reflect.Int: v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + case reflect.Uint32: + v = uint32(id) + case reflect.Uint64: + v = uint64(id) + case reflect.Uint: v = uint(id) } - pkValue.Set(reflect.ValueOf(v)) + aiValue.Set(reflect.ValueOf(v)) return res.RowsAffected() } else { - sql = sql + " RETURNING (id)" + //assert table.AutoIncrement != "" + sql = sql + " RETURNING " + session.Engine.Quote(table.AutoIncrement) res, err := session.query(sql, args...) if err != nil { return 0, err @@ -2263,25 +2288,31 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 0, errors.New("insert no error but not returned id") } - idByte := res[0][table.PrimaryKey] + idByte := res[0][table.AutoIncrement] id, err := strconv.ParseInt(string(idByte), 10, 64) if err != nil { return 1, err } - pkValue := table.PKColumn().ValueOf(bean) - if !pkValue.IsValid() || pkValue.Int() != 0 || !pkValue.CanSet() { + aiValue := table.AutoIncrColumn().ValueOf(bean) + if !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() { return 1, nil } var v interface{} = id - switch pkValue.Type().Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int: + switch aiValue.Type().Kind() { + case reflect.Int32: + v = int32(id) + case reflect.Int: v = int(id) - case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: + case reflect.Uint32: + v = uint32(id) + case reflect.Uint64: + v = uint64(id) + case reflect.Uint: v = uint(id) } - pkValue.Set(reflect.ValueOf(v)) + aiValue.Set(reflect.ValueOf(v)) return 1, nil } @@ -2304,14 +2335,14 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { } func (statement *Statement) convertUpdateSql(sql string) (string, string) { - if statement.RefTable == nil || statement.RefTable.PrimaryKey == "" { + if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { return "", "" } sqls := splitNNoCase(sql, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { return sqls[0], fmt.Sprintf("SELECT %v FROM %v", - statement.Engine.Quote(statement.RefTable.PrimaryKey), + statement.Engine.Quote(statement.RefTable.PrimaryKeys[0]), statement.Engine.Quote(statement.RefTable.Name)) } return "", "" @@ -2320,22 +2351,31 @@ func (statement *Statement) convertUpdateSql(sql string) (string, string) { var whereStr = sqls[1] //TODO: for postgres only, if any other database? - if strings.Contains(sqls[1], "$") { - dollers := strings.Split(sqls[1], "$") - whereStr = dollers[0] - for i, c := range dollers[1:] { - ccs := strings.SplitN(c, " ", 2) - whereStr += fmt.Sprintf("$%v %v", i+1, ccs[1]) + var paraStr string + if statement.Engine.dialect.DBType() == POSTGRES { + paraStr = "$" + } else if statement.Engine.dialect.DBType() == MSSQL { + paraStr = ":" + } + + if paraStr != "" { + if strings.Contains(sqls[1], paraStr) { + dollers := strings.Split(sqls[1], paraStr) + whereStr = dollers[0] + for i, c := range dollers[1:] { + ccs := strings.SplitN(c, " ", 2) + whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1]) + } } } return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v", - statement.Engine.Quote(statement.RefTable.PrimaryKey), statement.Engine.Quote(statement.TableName()), + statement.Engine.Quote(statement.RefTable.PrimaryKeys[0]), statement.Engine.Quote(statement.TableName()), whereStr) } func (session *Session) cacheInsert(tables ...string) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } @@ -2351,7 +2391,7 @@ func (session *Session) cacheInsert(tables ...string) error { } func (session *Session) cacheUpdate(sql string, args ...interface{}) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } @@ -2389,7 +2429,7 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") } else { id, err = strconv.ParseInt(string(v), 10, 64) @@ -2638,7 +2678,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } func (session *Session) cacheDelete(sql string, args ...interface{}) error { - if session.Statement.RefTable == nil || session.Statement.RefTable.PrimaryKey == "" { + if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { return ErrCacheFailed } @@ -2663,7 +2703,7 @@ func (session *Session) cacheDelete(sql string, args ...interface{}) error { if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKey]; !ok { + if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { return errors.New("no id") } else { id, err = strconv.ParseInt(string(v), 10, 64) diff --git a/sqlite3.go b/sqlite3.go index 84a9d1b0..d52b2966 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -190,16 +190,19 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { indexes := make(map[string]*Index, 0) for _, record := range res { - var sql string index := new(Index) - for name, content := range record { - if name == "sql" { - sql = string(content) - } + sql := string(record["sql"]) + + if sql == "" { + continue } nNStart := strings.Index(sql, "INDEX") nNEnd := strings.Index(sql, "ON") + if nNStart == -1 || nNEnd == -1 { + continue + } + indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") //fmt.Println(indexName) if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { diff --git a/statement.go b/statement.go index 59c9421f..18834c1a 100644 --- a/statement.go +++ b/statement.go @@ -335,11 +335,15 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, } else { engine.autoMapType(fieldValue.Type()) if table, ok := engine.Tables[fieldValue.Type()]; ok { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) - if pkField.Int() != 0 { - val = pkField.Interface() + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + if pkField.Int() != 0 { + val = pkField.Interface() + } else { + continue + } } else { - continue + //TODO: how to handler? } } else { val = fieldValue.Interface() @@ -753,9 +757,10 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} statement.ConditionStr = strings.Join(colNames, " AND ") statement.BeanArgs = args - var id string = "*" - if table.PrimaryKey != "" { - id = statement.Engine.Quote(table.PrimaryKey) + // count(index fieldname) > count(0) > count(*) + var id string = "0" + if len(table.PrimaryKeys) == 1 { + id = statement.Engine.Quote(table.PrimaryKeys[0]) } return statement.genSelectSql(fmt.Sprintf("COUNT(%v) AS %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) } diff --git a/table.go b/table.go index aac87528..1eee795f 100644 --- a/table.go +++ b/table.go @@ -339,16 +339,17 @@ func (col *Column) ValueOf(bean interface{}) reflect.Value { // database table type Table struct { - Name string - Type reflect.Type - ColumnsSeq []string - Columns map[string]*Column - Indexes map[string]*Index - PrimaryKey string - Created map[string]bool - Updated string - Version string - Cacher Cacher + Name string + Type reflect.Type + ColumnsSeq []string + Columns map[string]*Column + Indexes map[string]*Index + PrimaryKeys []string + AutoIncrement string + Created map[string]bool + Updated string + Version string + Cacher Cacher } /* @@ -362,8 +363,16 @@ func NewTable(name string, t reflect.Type) *Table { }*/ // if has primary key, return column -func (table *Table) PKColumn() *Column { - return table.Columns[table.PrimaryKey] +func (table *Table) PKColumns() []*Column { + columns := make([]*Column, 0) + for _, name := range table.PrimaryKeys { + columns = append(columns, table.Columns[name]) + } + return columns +} + +func (table *Table) AutoIncrColumn() *Column { + return table.Columns[table.AutoIncrement] } func (table *Table) VersionColumn() *Column { @@ -375,7 +384,10 @@ func (table *Table) AddColumn(col *Column) { table.ColumnsSeq = append(table.ColumnsSeq, col.Name) table.Columns[col.Name] = col if col.IsPrimaryKey { - table.PrimaryKey = col.Name + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + } + if col.IsAutoIncrement { + table.AutoIncrement = col.Name } if col.IsCreated { table.Created[col.Name] = true @@ -408,8 +420,21 @@ func (table *Table) genCols(session *Session, bean interface{}, useCol bool, inc } fieldValue := col.ValueOf(bean) - if col.IsAutoIncrement && fieldValue.Int() == 0 { - continue + if col.IsAutoIncrement { + switch fieldValue.Type().Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64: + if fieldValue.Uint() == 0 { + continue + } + case reflect.String: + if len(fieldValue.String()) == 0 { + continue + } + } } if session.Statement.ColumnStr != "" {