feat: passed ci for cockroach

1. support auto inrc
2. sequence mode
This commit is contained in:
luoji 2024-11-28 19:57:31 +08:00
parent 096563cbf1
commit 824b65740b
8 changed files with 110 additions and 11 deletions

View File

@ -28,6 +28,9 @@ type URI struct {
Raddr string
Timeout time.Duration
Schema string
// for cockrocah
Serial string
}
// SetSchema set schema
@ -79,6 +82,7 @@ type Dialect interface {
CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error)
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
DropSequenceSQL(seqName string) (string, error)
NextvalSequenceSQL(seqName string) string
GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error)
@ -170,6 +174,10 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) {
return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil
}
func (db *Base) NextvalSequenceSQL(seqName string) string {
return seqName + ".nextval"
}
// DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote

View File

@ -14,6 +14,7 @@ import (
"strings"
"xorm.io/xorm/v2/internal/core"
"xorm.io/xorm/v2/internal/utils"
"xorm.io/xorm/v2/schemas"
)
@ -913,12 +914,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
res = schemas.Boolean
return res
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
if c.IsAutoIncrement {
if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
return schemas.Serial
}
return schemas.Integer
case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt:
if c.IsAutoIncrement {
if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
return schemas.BigSerial
}
return schemas.BigInt
@ -947,7 +948,7 @@ func (db *postgres) SQLType(c *schemas.Column) string {
case schemas.Double, schemas.UnsignedFloat:
return "DOUBLE PRECISION"
default:
if c.IsAutoIncrement {
if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
return schemas.Serial
}
res = t
@ -969,8 +970,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
}
func (db *postgres) Features() *DialectFeatures {
var autoincrMode = IncrAutoincrMode
if db.uri.Serial == "sql_sequence" {
autoincrMode = SequenceAutoincrMode
}
return &DialectFeatures{
AutoincrMode: IncrAutoincrMode,
AutoincrMode: autoincrMode,
}
}
@ -1357,6 +1362,17 @@ func (db *postgres) GetIndexes(ctx context.Context, queryer core.Queryer, tableN
}
func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
// compatible with sql sequence of cockroach
if db.dialect.URI().Serial == "sql_sequence" {
for _, col := range table.Columns() {
if col.IsAutoIncrement && col.IsPrimaryKey {
col.DefaultIsEmpty = false
col.Default = db.NextvalSequenceSQL(utils.SeqName(tableName))
break
}
}
}
quoter := db.dialect.Quoter()
if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") {
tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName)
@ -1384,6 +1400,37 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
return createTableSQL + commentSQL, true, nil
}
func (db *postgres) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) {
return fmt.Sprintf(`CREATE SEQUENCE %s
minvalue 1
start with 1
increment by 1`, seqName), nil
}
func (db *postgres) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
var cnt int
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM pg_class WHERE relkind = 'S' and relname = $1", strings.ToLower(seqName))
if err != nil {
return false, err
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return false, rows.Err()
}
return false, errors.New("query sequence failed")
}
if err := rows.Scan(&cnt); err != nil {
return false, err
}
return cnt > 0, nil
}
func (db *postgres) NextvalSequenceSQL(seqName string) string {
return "nextval('" + seqName + "')"
}
func (db *postgres) Filters() []Filter {
return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}}
}
@ -1413,6 +1460,19 @@ func parseURL(connstr string) (string, error) {
return "", nil
}
func queryURL(connstr, key string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
return "", err
}
if u.Scheme != "postgresql" && u.Scheme != "postgres" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
return u.Query().Get(key), nil
}
func parseOpts(urlStr string, o values) error {
if len(urlStr) == 0 {
return fmt.Errorf("invalid options: %s", urlStr)
@ -1505,6 +1565,11 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
if err != nil {
return nil, err
}
db.Serial, err = queryURL(dataSourceName, "experimental_serial_normalization")
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)

View File

@ -423,6 +423,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
}
if tp[0] == schemas.POSTGRES {
destURI.Schema = engine.dialect.URI().Schema
destURI.Serial = uri.Serial
}
if err := dstDialect.Init(&destURI); err != nil {
return err
@ -760,7 +761,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
}
// FIXME: Hack for postgres
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
if dstDialect.URI().DBType == schemas.POSTGRES && dstDialect.URI().Serial != "sql_sequence" && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n")
if err != nil {
return err

View File

@ -130,7 +130,8 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr any) (int64, error) {
if i == 0 {
colNames = append(colNames, col.Name)
}
colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval")
colPlaces = append(colPlaces, session.engine.dialect.NextvalSequenceSQL(utils.SeqName(tableName)))
}
continue
}

View File

@ -56,9 +56,23 @@ func (statement *Statement) writeOrderCond(orderCondWriter *builder.BytesWriter,
return err
}
}
if statement.dialect.URI().Serial == "sql_sequence" {
cols := statement.RefTable.Columns()
if len(cols) == 0 {
return fmt.Errorf("no column")
}
colName := cols[0].Name
if _, err := fmt.Fprintf(orderCondWriter, colName+" IN (SELECT "+colName+" FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
return err
}
} else {
if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
return err
}
}
orderCondWriter.Append(orderSQLWriter.Args()...)
return nil
case schemas.SQLITE:

View File

@ -103,7 +103,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string,
return "", nil, err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil {
return "", nil, err
}
}
@ -143,7 +143,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string,
return "", nil, err
}
}
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil {
return "", nil, err
}
}

View File

@ -233,7 +233,9 @@ func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error {
return nil
}
if statement.dialect.URI().DBType != schemas.MYSQL && statement.dialect.URI().DBType != schemas.POSTGRES {
if statement.dialect.URI().DBType != schemas.MYSQL &&
!(statement.dialect.URI().DBType == schemas.POSTGRES &&
statement.dialect.URI().Serial != "sql_sequence") {
return errors.New("only support mysql and postgres for update")
}
_, err := fmt.Fprint(w, " FOR UPDATE")

View File

@ -94,6 +94,14 @@ func createEngine(dbType, connStr string) error {
}
}
db.Close()
u, err := url.Parse(connStr)
if err != nil {
return err
}
if u.Query().Get("experimental_serial_normalization") == "sql_sequence" {
*ignoreSelectUpdate = true
}
case schemas.MYSQL:
db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql"))
if err != nil {