From 824b65740b302e2018accb145f66f3b2486c7df5 Mon Sep 17 00:00:00 2001 From: luoji Date: Thu, 28 Nov 2024 19:57:31 +0800 Subject: [PATCH] feat: passed ci for cockroach 1. support auto inrc 2. sequence mode --- dialects/dialect.go | 8 ++++ dialects/postgres.go | 73 +++++++++++++++++++++++++++++++++-- engine.go | 3 +- insert.go | 3 +- internal/statements/delete.go | 18 ++++++++- internal/statements/insert.go | 4 +- internal/statements/query.go | 4 +- tests/tests.go | 8 ++++ 8 files changed, 110 insertions(+), 11 deletions(-) diff --git a/dialects/dialect.go b/dialects/dialect.go index b1d26c63..0c2a1e5c 100644 --- a/dialects/dialect.go +++ b/dialects/dialect.go @@ -28,6 +28,9 @@ type URI struct { Raddr string Timeout time.Duration Schema string + + // for cockrocah + Serial string } // SetSchema set schema @@ -79,6 +82,7 @@ type Dialect interface { CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) DropSequenceSQL(seqName string) (string, error) + NextvalSequenceSQL(seqName string) string GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error) @@ -170,6 +174,10 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) { return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil } +func (db *Base) NextvalSequenceSQL(seqName string) string { + return seqName + ".nextval" +} + // DropTableSQL returns drop table SQL func (db *Base) DropTableSQL(tableName string) (string, bool) { quote := db.dialect.Quoter().Quote diff --git a/dialects/postgres.go b/dialects/postgres.go index e436dfb6..5251e505 100644 --- a/dialects/postgres.go +++ b/dialects/postgres.go @@ -14,6 +14,7 @@ import ( "strings" "xorm.io/xorm/v2/internal/core" + "xorm.io/xorm/v2/internal/utils" "xorm.io/xorm/v2/schemas" ) @@ -913,12 +914,12 @@ func (db *postgres) SQLType(c *schemas.Column) string { res = schemas.Boolean return res case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt: - if c.IsAutoIncrement { + if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" { return schemas.Serial } return schemas.Integer case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt: - if c.IsAutoIncrement { + if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" { return schemas.BigSerial } return schemas.BigInt @@ -947,7 +948,7 @@ func (db *postgres) SQLType(c *schemas.Column) string { case schemas.Double, schemas.UnsignedFloat: return "DOUBLE PRECISION" default: - if c.IsAutoIncrement { + if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" { return schemas.Serial } res = t @@ -969,8 +970,12 @@ func (db *postgres) SQLType(c *schemas.Column) string { } func (db *postgres) Features() *DialectFeatures { + var autoincrMode = IncrAutoincrMode + if db.uri.Serial == "sql_sequence" { + autoincrMode = SequenceAutoincrMode + } return &DialectFeatures{ - AutoincrMode: IncrAutoincrMode, + AutoincrMode: autoincrMode, } } @@ -1357,6 +1362,17 @@ func (db *postgres) GetIndexes(ctx context.Context, queryer core.Queryer, tableN } func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { + // compatible with sql sequence of cockroach + if db.dialect.URI().Serial == "sql_sequence" { + for _, col := range table.Columns() { + if col.IsAutoIncrement && col.IsPrimaryKey { + col.DefaultIsEmpty = false + col.Default = db.NextvalSequenceSQL(utils.SeqName(tableName)) + break + } + } + } + quoter := db.dialect.Quoter() if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") { tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName) @@ -1384,6 +1400,37 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta return createTableSQL + commentSQL, true, nil } +func (db *postgres) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) { + return fmt.Sprintf(`CREATE SEQUENCE %s + minvalue 1 + start with 1 + increment by 1`, seqName), nil +} + +func (db *postgres) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) { + var cnt int + rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM pg_class WHERE relkind = 'S' and relname = $1", strings.ToLower(seqName)) + if err != nil { + return false, err + } + defer rows.Close() + if !rows.Next() { + if rows.Err() != nil { + return false, rows.Err() + } + return false, errors.New("query sequence failed") + } + + if err := rows.Scan(&cnt); err != nil { + return false, err + } + return cnt > 0, nil +} + +func (db *postgres) NextvalSequenceSQL(seqName string) string { + return "nextval('" + seqName + "')" +} + func (db *postgres) Filters() []Filter { return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}} } @@ -1413,6 +1460,19 @@ func parseURL(connstr string) (string, error) { return "", nil } +func queryURL(connstr, key string) (string, error) { + u, err := url.Parse(connstr) + if err != nil { + return "", err + } + + if u.Scheme != "postgresql" && u.Scheme != "postgres" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + return u.Query().Get(key), nil +} + func parseOpts(urlStr string, o values) error { if len(urlStr) == 0 { return fmt.Errorf("invalid options: %s", urlStr) @@ -1505,6 +1565,11 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) { if err != nil { return nil, err } + + db.Serial, err = queryURL(dataSourceName, "experimental_serial_normalization") + if err != nil { + return nil, err + } } else { o := make(values) err = parseOpts(dataSourceName, o) diff --git a/engine.go b/engine.go index 432b92ed..afff0e8a 100644 --- a/engine.go +++ b/engine.go @@ -423,6 +423,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w } if tp[0] == schemas.POSTGRES { destURI.Schema = engine.dialect.URI().Schema + destURI.Serial = uri.Serial } if err := dstDialect.Init(&destURI); err != nil { return err @@ -760,7 +761,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w } // FIXME: Hack for postgres - if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { + if dstDialect.URI().DBType == schemas.POSTGRES && dstDialect.URI().Serial != "sql_sequence" && table.AutoIncrColumn() != nil { _, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n") if err != nil { return err diff --git a/insert.go b/insert.go index f62f326d..d49f7874 100644 --- a/insert.go +++ b/insert.go @@ -130,7 +130,8 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr any) (int64, error) { if i == 0 { colNames = append(colNames, col.Name) } - colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval") + + colPlaces = append(colPlaces, session.engine.dialect.NextvalSequenceSQL(utils.SeqName(tableName))) } continue } diff --git a/internal/statements/delete.go b/internal/statements/delete.go index 148efd6d..e4e22565 100644 --- a/internal/statements/delete.go +++ b/internal/statements/delete.go @@ -56,9 +56,23 @@ func (statement *Statement) writeOrderCond(orderCondWriter *builder.BytesWriter, return err } } - if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil { - return err + + if statement.dialect.URI().Serial == "sql_sequence" { + cols := statement.RefTable.Columns() + if len(cols) == 0 { + return fmt.Errorf("no column") + } + colName := cols[0].Name + + if _, err := fmt.Fprintf(orderCondWriter, colName+" IN (SELECT "+colName+" FROM %s%s)", tableName, orderSQLWriter.String()); err != nil { + return err + } + } else { + if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil { + return err + } } + orderCondWriter.Append(orderSQLWriter.Args()...) return nil case schemas.SQLITE: diff --git a/internal/statements/insert.go b/internal/statements/insert.go index 1b9ed5e6..dd55c9b2 100644 --- a/internal/statements/insert.go +++ b/internal/statements/insert.go @@ -103,7 +103,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string, return "", nil, err } } - if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil { return "", nil, err } } @@ -143,7 +143,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string, return "", nil, err } } - if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { + if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil { return "", nil, err } } diff --git a/internal/statements/query.go b/internal/statements/query.go index 9dd0857f..1d9a997c 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -233,7 +233,9 @@ func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error { return nil } - if statement.dialect.URI().DBType != schemas.MYSQL && statement.dialect.URI().DBType != schemas.POSTGRES { + if statement.dialect.URI().DBType != schemas.MYSQL && + !(statement.dialect.URI().DBType == schemas.POSTGRES && + statement.dialect.URI().Serial != "sql_sequence") { return errors.New("only support mysql and postgres for update") } _, err := fmt.Fprint(w, " FOR UPDATE") diff --git a/tests/tests.go b/tests/tests.go index f7706096..c83cb586 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -94,6 +94,14 @@ func createEngine(dbType, connStr string) error { } } db.Close() + + u, err := url.Parse(connStr) + if err != nil { + return err + } + if u.Query().Get("experimental_serial_normalization") == "sql_sequence" { + *ignoreSelectUpdate = true + } case schemas.MYSQL: db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql")) if err != nil {