feat: passed ci for cockroach
1. support auto inrc 2. sequence mode
This commit is contained in:
parent
096563cbf1
commit
824b65740b
|
@ -28,6 +28,9 @@ type URI struct {
|
||||||
Raddr string
|
Raddr string
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
Schema string
|
Schema string
|
||||||
|
|
||||||
|
// for cockrocah
|
||||||
|
Serial string
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSchema set schema
|
// SetSchema set schema
|
||||||
|
@ -79,6 +82,7 @@ type Dialect interface {
|
||||||
CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error)
|
CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error)
|
||||||
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
|
IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error)
|
||||||
DropSequenceSQL(seqName string) (string, error)
|
DropSequenceSQL(seqName string) (string, error)
|
||||||
|
NextvalSequenceSQL(seqName string) string
|
||||||
|
|
||||||
GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error)
|
GetColumns(ctx context.Context, queryer core.Queryer, tableName string) ([]string, map[string]*schemas.Column, error)
|
||||||
IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error)
|
IsColumnExist(ctx context.Context, queryer core.Queryer, tableName string, colName string) (bool, error)
|
||||||
|
@ -170,6 +174,10 @@ func (db *Base) DropSequenceSQL(seqName string) (string, error) {
|
||||||
return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil
|
return fmt.Sprintf("DROP SEQUENCE %s", seqName), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *Base) NextvalSequenceSQL(seqName string) string {
|
||||||
|
return seqName + ".nextval"
|
||||||
|
}
|
||||||
|
|
||||||
// DropTableSQL returns drop table SQL
|
// DropTableSQL returns drop table SQL
|
||||||
func (db *Base) DropTableSQL(tableName string) (string, bool) {
|
func (db *Base) DropTableSQL(tableName string) (string, bool) {
|
||||||
quote := db.dialect.Quoter().Quote
|
quote := db.dialect.Quoter().Quote
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"xorm.io/xorm/v2/internal/core"
|
"xorm.io/xorm/v2/internal/core"
|
||||||
|
"xorm.io/xorm/v2/internal/utils"
|
||||||
"xorm.io/xorm/v2/schemas"
|
"xorm.io/xorm/v2/schemas"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -913,12 +914,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
|
||||||
res = schemas.Boolean
|
res = schemas.Boolean
|
||||||
return res
|
return res
|
||||||
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
|
case schemas.MediumInt, schemas.Int, schemas.Integer, schemas.UnsignedMediumInt, schemas.UnsignedSmallInt:
|
||||||
if c.IsAutoIncrement {
|
if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
|
||||||
return schemas.Serial
|
return schemas.Serial
|
||||||
}
|
}
|
||||||
return schemas.Integer
|
return schemas.Integer
|
||||||
case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt:
|
case schemas.BigInt, schemas.UnsignedBigInt, schemas.UnsignedInt:
|
||||||
if c.IsAutoIncrement {
|
if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
|
||||||
return schemas.BigSerial
|
return schemas.BigSerial
|
||||||
}
|
}
|
||||||
return schemas.BigInt
|
return schemas.BigInt
|
||||||
|
@ -947,7 +948,7 @@ func (db *postgres) SQLType(c *schemas.Column) string {
|
||||||
case schemas.Double, schemas.UnsignedFloat:
|
case schemas.Double, schemas.UnsignedFloat:
|
||||||
return "DOUBLE PRECISION"
|
return "DOUBLE PRECISION"
|
||||||
default:
|
default:
|
||||||
if c.IsAutoIncrement {
|
if c.IsAutoIncrement && db.dialect.URI().Serial != "sql_sequence" {
|
||||||
return schemas.Serial
|
return schemas.Serial
|
||||||
}
|
}
|
||||||
res = t
|
res = t
|
||||||
|
@ -969,8 +970,12 @@ func (db *postgres) SQLType(c *schemas.Column) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Features() *DialectFeatures {
|
func (db *postgres) Features() *DialectFeatures {
|
||||||
|
var autoincrMode = IncrAutoincrMode
|
||||||
|
if db.uri.Serial == "sql_sequence" {
|
||||||
|
autoincrMode = SequenceAutoincrMode
|
||||||
|
}
|
||||||
return &DialectFeatures{
|
return &DialectFeatures{
|
||||||
AutoincrMode: IncrAutoincrMode,
|
AutoincrMode: autoincrMode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1357,6 +1362,17 @@ func (db *postgres) GetIndexes(ctx context.Context, queryer core.Queryer, tableN
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
|
func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
|
||||||
|
// compatible with sql sequence of cockroach
|
||||||
|
if db.dialect.URI().Serial == "sql_sequence" {
|
||||||
|
for _, col := range table.Columns() {
|
||||||
|
if col.IsAutoIncrement && col.IsPrimaryKey {
|
||||||
|
col.DefaultIsEmpty = false
|
||||||
|
col.Default = db.NextvalSequenceSQL(utils.SeqName(tableName))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
quoter := db.dialect.Quoter()
|
quoter := db.dialect.Quoter()
|
||||||
if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") {
|
if len(db.getSchema()) != 0 && !strings.Contains(tableName, ".") {
|
||||||
tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName)
|
tableName = fmt.Sprintf("%s.%s", db.getSchema(), tableName)
|
||||||
|
@ -1384,6 +1400,37 @@ func (db *postgres) CreateTableSQL(ctx context.Context, queryer core.Queryer, ta
|
||||||
return createTableSQL + commentSQL, true, nil
|
return createTableSQL + commentSQL, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *postgres) CreateSequenceSQL(ctx context.Context, queryer core.Queryer, seqName string) (string, error) {
|
||||||
|
return fmt.Sprintf(`CREATE SEQUENCE %s
|
||||||
|
minvalue 1
|
||||||
|
start with 1
|
||||||
|
increment by 1`, seqName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *postgres) IsSequenceExist(ctx context.Context, queryer core.Queryer, seqName string) (bool, error) {
|
||||||
|
var cnt int
|
||||||
|
rows, err := queryer.QueryContext(ctx, "SELECT COUNT(*) FROM pg_class WHERE relkind = 'S' and relname = $1", strings.ToLower(seqName))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
if rows.Err() != nil {
|
||||||
|
return false, rows.Err()
|
||||||
|
}
|
||||||
|
return false, errors.New("query sequence failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rows.Scan(&cnt); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cnt > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *postgres) NextvalSequenceSQL(seqName string) string {
|
||||||
|
return "nextval('" + seqName + "')"
|
||||||
|
}
|
||||||
|
|
||||||
func (db *postgres) Filters() []Filter {
|
func (db *postgres) Filters() []Filter {
|
||||||
return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}}
|
return []Filter{&postgresSeqFilter{Prefix: "$", Start: 1}}
|
||||||
}
|
}
|
||||||
|
@ -1413,6 +1460,19 @@ func parseURL(connstr string) (string, error) {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func queryURL(connstr, key string) (string, error) {
|
||||||
|
u, err := url.Parse(connstr)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != "postgresql" && u.Scheme != "postgres" {
|
||||||
|
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.Query().Get(key), nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseOpts(urlStr string, o values) error {
|
func parseOpts(urlStr string, o values) error {
|
||||||
if len(urlStr) == 0 {
|
if len(urlStr) == 0 {
|
||||||
return fmt.Errorf("invalid options: %s", urlStr)
|
return fmt.Errorf("invalid options: %s", urlStr)
|
||||||
|
@ -1505,6 +1565,11 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*URI, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db.Serial, err = queryURL(dataSourceName, "experimental_serial_normalization")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
o := make(values)
|
o := make(values)
|
||||||
err = parseOpts(dataSourceName, o)
|
err = parseOpts(dataSourceName, o)
|
||||||
|
|
|
@ -423,6 +423,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
|
||||||
}
|
}
|
||||||
if tp[0] == schemas.POSTGRES {
|
if tp[0] == schemas.POSTGRES {
|
||||||
destURI.Schema = engine.dialect.URI().Schema
|
destURI.Schema = engine.dialect.URI().Schema
|
||||||
|
destURI.Serial = uri.Serial
|
||||||
}
|
}
|
||||||
if err := dstDialect.Init(&destURI); err != nil {
|
if err := dstDialect.Init(&destURI); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -760,7 +761,7 @@ func (engine *Engine) dumpTables(ctx context.Context, tables []*schemas.Table, w
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: Hack for postgres
|
// FIXME: Hack for postgres
|
||||||
if dstDialect.URI().DBType == schemas.POSTGRES && table.AutoIncrColumn() != nil {
|
if dstDialect.URI().DBType == schemas.POSTGRES && dstDialect.URI().Serial != "sql_sequence" && table.AutoIncrColumn() != nil {
|
||||||
_, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n")
|
_, err = io.WriteString(w, "SELECT setval('"+dstTableName+"_id_seq', COALESCE((SELECT MAX("+table.AutoIncrColumn().Name+") + 1 FROM "+dstDialect.Quoter().Quote(dstTableName)+"), 1), false);\n")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -130,7 +130,8 @@ func (session *Session) insertMultipleStruct(rowsSlicePtr any) (int64, error) {
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
colNames = append(colNames, col.Name)
|
colNames = append(colNames, col.Name)
|
||||||
}
|
}
|
||||||
colPlaces = append(colPlaces, utils.SeqName(tableName)+".nextval")
|
|
||||||
|
colPlaces = append(colPlaces, session.engine.dialect.NextvalSequenceSQL(utils.SeqName(tableName)))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,9 +56,23 @@ func (statement *Statement) writeOrderCond(orderCondWriter *builder.BytesWriter,
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
|
|
||||||
return err
|
if statement.dialect.URI().Serial == "sql_sequence" {
|
||||||
|
cols := statement.RefTable.Columns()
|
||||||
|
if len(cols) == 0 {
|
||||||
|
return fmt.Errorf("no column")
|
||||||
|
}
|
||||||
|
colName := cols[0].Name
|
||||||
|
|
||||||
|
if _, err := fmt.Fprintf(orderCondWriter, colName+" IN (SELECT "+colName+" FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := fmt.Fprintf(orderCondWriter, "ctid IN (SELECT ctid FROM %s%s)", tableName, orderSQLWriter.String()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
orderCondWriter.Append(orderSQLWriter.Args()...)
|
orderCondWriter.Append(orderSQLWriter.Args()...)
|
||||||
return nil
|
return nil
|
||||||
case schemas.SQLITE:
|
case schemas.SQLITE:
|
||||||
|
|
|
@ -103,7 +103,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string,
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
|
if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -143,7 +143,7 @@ func (statement *Statement) GenInsertSQL(colNames []string, args []any) (string,
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if _, err := buf.WriteString(utils.SeqName(tableName) + ".nextval"); err != nil {
|
if _, err := buf.WriteString(statement.dialect.NextvalSequenceSQL(utils.SeqName(tableName))); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -233,7 +233,9 @@ func (statement *Statement) writeForUpdate(w *builder.BytesWriter) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if statement.dialect.URI().DBType != schemas.MYSQL && statement.dialect.URI().DBType != schemas.POSTGRES {
|
if statement.dialect.URI().DBType != schemas.MYSQL &&
|
||||||
|
!(statement.dialect.URI().DBType == schemas.POSTGRES &&
|
||||||
|
statement.dialect.URI().Serial != "sql_sequence") {
|
||||||
return errors.New("only support mysql and postgres for update")
|
return errors.New("only support mysql and postgres for update")
|
||||||
}
|
}
|
||||||
_, err := fmt.Fprint(w, " FOR UPDATE")
|
_, err := fmt.Fprint(w, " FOR UPDATE")
|
||||||
|
|
|
@ -94,6 +94,14 @@ func createEngine(dbType, connStr string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
db.Close()
|
db.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(connStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if u.Query().Get("experimental_serial_normalization") == "sql_sequence" {
|
||||||
|
*ignoreSelectUpdate = true
|
||||||
|
}
|
||||||
case schemas.MYSQL:
|
case schemas.MYSQL:
|
||||||
db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql"))
|
db, err := sql.Open(dbType, strings.ReplaceAll(connStr, "xorm_test", "mysql"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue