feat: passed ci for cockroach

1. support auto inrc
2. sequence mode
This commit is contained in:
luoji 2024-11-28 19:57:31 +08:00
parent 096563cbf1
commit 824b65740b
8 changed files with 110 additions and 11 deletions

View File

@ -28,6 +28,9 @@ type URI struct {
Raddr string Raddr string
Timeout time.Duration Timeout time.Duration
Schema string Schema string
// for cockrocah
Serial string
} }
// SetSchema set schema // SetSchema set schema
@ -79,6 +82,7 @@ type Dialect interface {
CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error)
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
DropSequenceSQL(seqName string) (string, error) DropSequenceSQL(seqName string) (string, error)
NextvalSequenceSQL(seqName string) string
GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error) GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error)
IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error) IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error)
@ -170,6 +174,10 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) {
return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil
} }
func (db *Base) NextvalSequenceSQL(seqName string) string {
return seqName + ".nextval"
}
// DropTableSQL returns drop table SQL // DropTableSQL returns drop table SQL
func (db *Base) DropTableSQL(tableName string) (string, bool) { func (db *Base) DropTableSQL(tableName string) (string, bool) {
quote := db.dialect.Quoter().Quote quote := db.dialect.Quoter().Quote

View File

@ -14,6 +14,7 @@ import (
"strings" "strings"
"xorm.io/xorm/v2/internal/core" "xorm.io/xorm/v2/internal/core"
"xorm.io/xorm/v2/internal/utils"
"xorm.io/xorm/v2/schemas" "xorm.io/xorm/v2/schemas"
) )
@ -913,12 +914,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
res = schemas.Boolean res = schemas.Boolean
return res return res
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt: case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
if c.IsAutoIncrement { if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
return schemas.Serial return schemas.Serial
} }
return schemas.Integer return schemas.Integer
case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt: case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt:
if c.IsAutoIncrement { if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
return schemas.BigSerial return schemas.BigSerial
} }
return schemas.BigInt return schemas.BigInt
@ -947,7 +948,7 @@ func (db *postgres) SQLType(c *schemas.Column) string {
case schemas.Double, schemas.UnsignedFloat: case schemas.Double, schemas.UnsignedFloat:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
return schemas.Serial return schemas.Serial
} }
res = t res = t
@ -969,8 +970,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
} }
func (db *postgres) Features() *DialectFeatures { func (db *postgres) Features() *DialectFeatures {
var autoincrMode = IncrAutoincrMode
if db.uri.Serial == "sql_sequence" {
autoincrMode = SequenceAutoincrMode
}
return &DialectFeatures{ return &DialectFeatures{
AutoincrMode: IncrAutoincrMode, AutoincrMode: autoincrMode,
} }
} }
@ -1357,6 +1362,17 @@ func (db *postgres) GetIndexes(ctx context.Context, queryer core.Queryer, tableN
} }
func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
// compatible with sql sequence of cockroach
if db.dialect.URI().Serial == "sql_sequence" {
for _, col := range table.Columns() {
if col.IsAutoIncrement && col.IsPrimaryKey {
col.DefaultIsEmpty = false
col.Default = db.NextvalSequenceSQL(utils.SeqName(tableName))
break
}
}
}
quoter := db.dialect.Quoter() quoter := db.dialect.Quoter()
if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") { if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") {
tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName) tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName)
@ -1384,6 +1400,37 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
return createTableSQL + commentSQL, true, nil return createTableSQL + commentSQL, true, nil
} }
func (db *postgres) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) {
return fmt.Sprintf(`CREATE SEQUENCE %s
minvalue 1
start with 1
increment by 1`, seqName), nil
}
func (db *postgres) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
var cnt int
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM pg_class WHERE relkind = 'S' and relname = $1", strings.ToLower(seqName))
if err != nil {
return false, err
}
defer rows.Close()
if !rows.Next() {
if rows.Err() != nil {
return false, rows.Err()
}
return false, errors.New("query sequence failed")
}
if err := rows.Scan(&cnt); err != nil {
return false, err
}
return cnt > 0, nil
}
func (db *postgres) NextvalSequenceSQL(seqName string) string {
return "nextval('" + seqName + "')"
}
func (db *postgres) Filters() []Filter { func (db *postgres) Filters() []Filter {
return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}} return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}}
} }
@ -1413,6 +1460,19 @@ func parseURL(connstr string) (string, error) {
return "", nil return "", nil
} }
func queryURL(connstr, key string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
return "", err
}
if u.Scheme != "postgresql" && u.Scheme != "postgres" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
return u.Query().Get(key), nil
}
func parseOpts(urlStr string, o values) error { func parseOpts(urlStr string, o values) error {
if len(urlStr) == 0 { if len(urlStr) == 0 {
return fmt.Errorf("invalid options: %s", urlStr) return fmt.Errorf("invalid options: %s", urlStr)
@ -1505,6 +1565,11 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
db.Serial, err = queryURL(dataSourceName, "experimental_serial_normalization")
if err != nil {
return nil, err
}
} else { } else {
o := make(values) o := make(values)
err = parseOpts(dataSourceName, o) err = parseOpts(dataSourceName, o)

View File

@ -423,6 +423,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
} }
if tp[0] == schemas.POSTGRES { if tp[0] == schemas.POSTGRES {
destURI.Schema = engine.dialect.URI().Schema destURI.Schema = engine.dialect.URI().Schema
destURI.Serial = uri.Serial
} }
if err := dstDialect.Init(&destURI); err != nil { if err := dstDialect.Init(&destURI); err != nil {
return err return err
@ -760,7 +761,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
} }
// FIXME: Hack for postgres // FIXME: Hack for postgres
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil { if dstDialect.URI().DBType == schemas.POSTGRES && dstDialect.URI().Serial != "sql_sequence" && table.AutoIncrColumn() != nil {
_, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n") _, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n")
if err != nil { if err != nil {
return err return err

View File

@ -130,7 +130,8 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr any) (int64, error) {
if i == 0 { if i == 0 {
colNames = append(colNames, col.Name) colNames = append(colNames, col.Name)
} }
colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval")
colPlaces = append(colPlaces, session.engine.dialect.NextvalSequenceSQL(utils.SeqName(tableName)))
} }
continue continue
} }

View File

@ -56,9 +56,23 @@ func (statement *Statement) writeOrderCond(orderCondWriter *builder.BytesWriter,
return err return err
} }
} }
if statement.dialect.URI().Serial == "sql_sequence" {
cols := statement.RefTable.Columns()
if len(cols) == 0 {
return fmt.Errorf("no column")
}
colName := cols[0].Name
if _, err := fmt.Fprintf(orderCondWriter, colName+" IN (SELECT "+colName+" FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
return err
}
} else {
if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil { if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
return err return err
} }
}
orderCondWriter.Append(orderSQLWriter.Args()...) orderCondWriter.Append(orderSQLWriter.Args()...)
return nil return nil
case schemas.SQLITE: case schemas.SQLITE:

View File

@ -103,7 +103,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string,
return "", nil, err return "", nil, err
} }
} }
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil {
return "", nil, err return "", nil, err
} }
} }
@ -143,7 +143,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string,
return "", nil, err return "", nil, err
} }
} }
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil { if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil {
return "", nil, err return "", nil, err
} }
} }

View File

@ -233,7 +233,9 @@ func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error {
return nil return nil
} }
if statement.dialect.URI().DBType != schemas.MYSQL && statement.dialect.URI().DBType != schemas.POSTGRES { if statement.dialect.URI().DBType != schemas.MYSQL &&
!(statement.dialect.URI().DBType == schemas.POSTGRES &&
statement.dialect.URI().Serial != "sql_sequence") {
return errors.New("only support mysql and postgres for update") return errors.New("only support mysql and postgres for update")
} }
_, err := fmt.Fprint(w, " FOR UPDATE") _, err := fmt.Fprint(w, " FOR UPDATE")

View File

@ -94,6 +94,14 @@ func createEngine(dbType, connStr string) error {
} }
} }
db.Close() db.Close()
u, err := url.Parse(connStr)
if err != nil {
return err
}
if u.Query().Get("experimental_serial_normalization") == "sql_sequence" {
*ignoreSelectUpdate = true
}
case schemas.MYSQL: case schemas.MYSQL:
db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql")) db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql"))
if err != nil { if err != nil {