Change schemas.Column to use int64

The largest size of a column in SQL is above the largest integer that
can be stored in an int/int32. Running go on a 32bit machine will result
in ints being mapped to int32 and thus interpretting the schema will
fail.

In this PR we change the schema.Column to use int64 which will allow
Gitea to fix https://github.com/go-gitea/gitea/issues/20161

Signed-off-by: Andrew Thornton <art27@cantab.net>
This commit is contained in:
Andrew Thornton 2022-07-12 18:51:58 +01:00
parent f469d88166
commit 72bfadeefd
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
9 changed files with 86 additions and 86 deletions

View File

@ -622,9 +622,9 @@ func (db *dameng) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -729,11 +729,11 @@ func (db *dameng) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *dameng) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = damengQuoter q := damengQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = damengQuoter q := damengQuoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -927,7 +927,7 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
var ( var (
ignore bool ignore bool
dt string dt string
len1, len2 int len1, len2 int64
) )
dts := strings.Split(dataType.String, "(") dts := strings.Split(dataType.String, "(")
@ -935,10 +935,10 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
if len(dts) > 1 { if len(dts) > 1 {
lens := strings.Split(dts[1][:len(dts[1])-1], ",") lens := strings.Split(dts[1][:len(dts[1])-1], ",")
if len(lens) > 1 { if len(lens) > 1 {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
len2, _ = strconv.Atoi(lens[1]) len2, _ = strconv.ParseInt(lens[1], 10, 64)
} else { } else {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
} }
} }
@ -972,9 +972,9 @@ func (db *dameng) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
} }
if col.SQLType.Name == "TIMESTAMP" { if col.SQLType.Name == "TIMESTAMP" {
col.Length = int(dataScale.Int64) col.Length = dataScale.Int64
} else { } else {
col.Length = int(dataLen.Int64) col.Length = dataLen.Int64
} }
if col.SQLType.IsTime() { if col.SQLType.IsTime() {
@ -1140,8 +1140,8 @@ func (d *damengDriver) GenScanResult(colType string) (interface{}, error) {
} }
func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error { func (d *damengDriver) Scan(ctx *ScanContext, rows *core.Rows, types []*sql.ColumnType, vv ...interface{}) error {
var scanResults = make([]interface{}, 0, len(types)) scanResults := make([]interface{}, 0, len(types))
var replaces = make([]bool, 0, len(types)) replaces := make([]bool, 0, len(types))
var err error var err error
for i, v := range vv { for i, v := range vv {
var replaced bool var replaced bool

View File

@ -229,7 +229,7 @@ func (db *mssql) Init(uri *URI) error {
func (db *mssql) SetParams(params map[string]string) { func (db *mssql) SetParams(params map[string]string) {
defaultVarchar, ok := params["DEFAULT_VARCHAR"] defaultVarchar, ok := params["DEFAULT_VARCHAR"]
if ok { if ok {
var t = strings.ToUpper(defaultVarchar) t := strings.ToUpper(defaultVarchar)
switch t { switch t {
case "NVARCHAR", "VARCHAR": case "NVARCHAR", "VARCHAR":
db.defaultVarchar = t db.defaultVarchar = t
@ -242,7 +242,7 @@ func (db *mssql) SetParams(params map[string]string) {
defaultChar, ok := params["DEFAULT_CHAR"] defaultChar, ok := params["DEFAULT_CHAR"]
if ok { if ok {
var t = strings.ToUpper(defaultChar) t := strings.ToUpper(defaultChar)
switch t { switch t {
case "NCHAR", "CHAR": case "NCHAR", "CHAR":
db.defaultChar = t db.defaultChar = t
@ -375,9 +375,9 @@ func (db *mssql) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -403,11 +403,11 @@ func (db *mssql) IsReserved(name string) bool {
func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *mssql) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = mssqlQuoter q := mssqlQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = mssqlQuoter q := mssqlQuoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -475,7 +475,7 @@ func (db *mssql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, vdefault string var name, ctype, vdefault string
var maxLen, precision, scale int var maxLen, precision, scale int64
var nullable, isPK, defaultIsNull, isIncrement bool var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement) err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement)
if err != nil { if err != nil {

View File

@ -330,9 +330,9 @@ func (db *mysql) SQLType(c *schemas.Column) string {
} }
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
if isUnsigned { if isUnsigned {
@ -444,7 +444,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
// Remove the /* mariadb-5.3 */ suffix from coltypes // Remove the /* mariadb-5.3 */ suffix from coltypes
colName = strings.TrimSuffix(colName, "/* mariadb-5.3 */") colName = strings.TrimSuffix(colName, "/* mariadb-5.3 */")
colType = strings.ToUpper(colName) colType = strings.ToUpper(colName)
var len1, len2 int var len1, len2 int64
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
if colType == schemas.Enum && cts[1][0] == '\'' { // enum if colType == schemas.Enum && cts[1][0] == '\'' { // enum
@ -465,12 +465,12 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} }
} else { } else {
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.ParseInt(strings.TrimSpace(lens[0]), 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.ParseInt(lens[1], 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -479,7 +479,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} else { } else {
switch colType { switch colType {
case "MEDIUMTEXT", "LONGTEXT", "TEXT": case "MEDIUMTEXT", "LONGTEXT", "TEXT":
len1, err = strconv.Atoi(*maxLength) len1, err = strconv.ParseInt(*maxLength, 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -570,9 +570,9 @@ func (db *oracle) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -606,7 +606,7 @@ func (db *oracle) DropTableSQL(tableName string) (string, bool) {
} }
func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) { func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, table *schemas.Table, tableName string) (string, bool, error) {
var sql = "CREATE TABLE " sql := "CREATE TABLE "
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
} }
@ -641,11 +641,11 @@ func (db *oracle) CreateTableSQL(ctx context.Context, queryer core.Queryer, tabl
func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) { func (db *oracle) SetQuotePolicy(quotePolicy QuotePolicy) {
switch quotePolicy { switch quotePolicy {
case QuotePolicyNone: case QuotePolicyNone:
var q = oracleQuoter q := oracleQuoter
q.IsReserved = schemas.AlwaysNoReserve q.IsReserved = schemas.AlwaysNoReserve
db.quoter = q db.quoter = q
case QuotePolicyReserved: case QuotePolicyReserved:
var q = oracleQuoter q := oracleQuoter
q.IsReserved = db.IsReserved q.IsReserved = db.IsReserved
db.quoter = q db.quoter = q
case QuotePolicyAlways: case QuotePolicyAlways:
@ -690,7 +690,7 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
col.Indexes = make(map[string]int) col.Indexes = make(map[string]int)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string var colName, colDefault, nullable, dataType, dataPrecision, dataScale *string
var dataLen int var dataLen int64
err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision, err = rows.Scan(&colName, &colDefault, &dataType, &dataLen, &dataPrecision,
&dataScale, &nullable) &dataScale, &nullable)
@ -713,16 +713,16 @@ func (db *oracle) GetColumns(queryer core.Queryer, ctx context.Context, tableNam
var ignore bool var ignore bool
var dt string var dt string
var len1, len2 int var len1, len2 int64
dts := strings.Split(*dataType, "(") dts := strings.Split(*dataType, "(")
dt = dts[0] dt = dts[0]
if len(dts) > 1 { if len(dts) > 1 {
lens := strings.Split(dts[1][:len(dts[1])-1], ",") lens := strings.Split(dts[1][:len(dts[1])-1], ",")
if len(lens) > 1 { if len(lens) > 1 {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
len2, _ = strconv.Atoi(lens[1]) len2, _ = strconv.ParseInt(lens[1], 10, 64)
} else { } else {
len1, _ = strconv.Atoi(lens[0]) len1, _ = strconv.ParseInt(lens[0], 10, 64)
} }
} }

View File

@ -934,9 +934,9 @@ func (db *postgres) SQLType(c *schemas.Column) string {
hasLen2 := (c.Length2 > 0) hasLen2 := (c.Length2 > 0)
if hasLen2 { if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + "," + strconv.FormatInt(c.Length2, 10) + ")"
} else if hasLen1 { } else if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.FormatInt(c.Length, 10) + ")"
} }
return res return res
} }
@ -1110,9 +1110,9 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
return nil, nil, err return nil, nil, err
} }
var maxLen int var maxLen int64
if maxLenStr != nil { if maxLenStr != nil {
maxLen, err = strconv.Atoi(*maxLenStr) maxLen, err = strconv.ParseInt(*maxLenStr, 10, 64)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1186,7 +1186,7 @@ WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s A
startIdx := strings.Index(strings.ToLower(dataType), "string(") startIdx := strings.Index(strings.ToLower(dataType), "string(")
if startIdx != -1 && strings.HasSuffix(dataType, ")") { if startIdx != -1 && strings.HasSuffix(dataType, ")") {
length := dataType[startIdx+8 : len(dataType)-1] length := dataType[startIdx+8 : len(dataType)-1]
l, _ := strconv.Atoi(length) l, _ := strconv.ParseInt(length, 10, 64)
col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: "STRING", DefaultLength: l, DefaultLength2: 0}
} else { } else {
col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0} col.SQLType = schemas.SQLType{Name: strings.ToUpper(dataType), DefaultLength: 0, DefaultLength2: 0}

View File

@ -23,7 +23,7 @@ func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.C
} }
} }
var tmZone = dbLocation tmZone := dbLocation
if col.TimeZone != nil { if col.TimeZone != nil {
tmZone = col.TimeZone tmZone = col.TimeZone
} }
@ -34,15 +34,17 @@ func FormatColumnTime(dialect Dialect, dbLocation *time.Location, col *schemas.C
case schemas.Date: case schemas.Date:
return t.Format("2006-01-02"), nil return t.Format("2006-01-02"), nil
case schemas.Time: case schemas.Time:
var layout = "15:04:05" layout := "15:04:05"
if col.Length > 0 { if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length) // we can use int(...) casting here as it's very unlikely to a huge sized field
layout += "." + strings.Repeat("0", int(col.Length))
} }
return t.Format(layout), nil return t.Format(layout), nil
case schemas.DateTime, schemas.TimeStamp: case schemas.DateTime, schemas.TimeStamp:
var layout = "2006-01-02 15:04:05" layout := "2006-01-02 15:04:05"
if col.Length > 0 { if col.Length > 0 {
layout += "." + strings.Repeat("0", col.Length) // we can use int(...) casting here as it's very unlikely to a huge sized field
layout += "." + strings.Repeat("0", int(col.Length))
} }
return t.Format(layout), nil return t.Format(layout), nil
case schemas.Varchar: case schemas.Varchar:

View File

@ -26,8 +26,8 @@ type Column struct {
FieldIndex []int // Available only when parsed from a struct FieldIndex []int // Available only when parsed from a struct
SQLType SQLType SQLType SQLType
IsJSON bool IsJSON bool
Length int Length int64
Length2 int Length2 int64
Nullable bool Nullable bool
Default string Default string
Indexes map[string]int Indexes map[string]int
@ -48,7 +48,7 @@ type Column struct {
} }
// NewColumn creates a new column // NewColumn creates a new column
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column { func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int64, nullable bool) *Column {
return &Column{ return &Column{
Name: name, Name: name,
IsJSON: sqlType.IsJson(), IsJSON: sqlType.IsJson(),
@ -82,7 +82,7 @@ func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
// ValueOfV returns column's filed of struct's value accept reflevt value // ValueOfV returns column's filed of struct's value accept reflevt value
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
var v = *dataStruct v := *dataStruct
for _, i := range col.FieldIndex { for _, i := range col.FieldIndex {
if v.Kind() == reflect.Ptr { if v.Kind() == reflect.Ptr {
if v.IsNil() { if v.IsNil() {

View File

@ -28,8 +28,8 @@ const (
// SQLType represents SQL types // SQLType represents SQL types
type SQLType struct { type SQLType struct {
Name string Name string
DefaultLength int DefaultLength int64
DefaultLength2 int DefaultLength2 int64
} }
// enumerates all columns types // enumerates all columns types

View File

@ -99,9 +99,8 @@ type Context struct {
// Handler describes tag handler for XORM // Handler describes tag handler for XORM
type Handler func(ctx *Context) error type Handler func(ctx *Context) error
var (
// defaultTagHandlers enumerates all the default tag handler // defaultTagHandlers enumerates all the default tag handler
defaultTagHandlers = map[string]Handler{ var defaultTagHandlers = map[string]Handler{
"-": IgnoreHandler, "-": IgnoreHandler,
"<-": OnlyFromDBTagHandler, "<-": OnlyFromDBTagHandler,
"->": OnlyToDBTagHandler, "->": OnlyToDBTagHandler,
@ -125,7 +124,6 @@ var (
"EXTENDS": ExtendsTagHandler, "EXTENDS": ExtendsTagHandler,
"UNSIGNED": UnsignedTagHandler, "UNSIGNED": UnsignedTagHandler,
} }
)
func init() { func init() {
for k := range schemas.SqlTypes { for k := range schemas.SqlTypes {
@ -312,16 +310,16 @@ func SQLTypeTagHandler(ctx *Context) error {
default: default:
var err error var err error
if len(ctx.params) == 2 { if len(ctx.params) == 2 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) ctx.col.Length, err = strconv.ParseInt(ctx.params[0], 10, 64)
if err != nil { if err != nil {
return err return err
} }
ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) ctx.col.Length2, err = strconv.ParseInt(ctx.params[1], 10, 64)
if err != nil { if err != nil {
return err return err
} }
} else if len(ctx.params) == 1 { } else if len(ctx.params) == 1 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) ctx.col.Length, err = strconv.ParseInt(ctx.params[0], 10, 64)
if err != nil { if err != nil {
return err return err
} }
@ -332,8 +330,8 @@ func SQLTypeTagHandler(ctx *Context) error {
// ExtendsTagHandler describes extends tag handler // ExtendsTagHandler describes extends tag handler
func ExtendsTagHandler(ctx *Context) error { func ExtendsTagHandler(ctx *Context) error {
var fieldValue = ctx.fieldValue fieldValue := ctx.fieldValue
var isPtr = false isPtr := false
switch fieldValue.Kind() { switch fieldValue.Kind() {
case reflect.Ptr: case reflect.Ptr:
f := fieldValue.Type().Elem() f := fieldValue.Type().Elem()
@ -355,7 +353,7 @@ func ExtendsTagHandler(ctx *Context) error {
col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName) col.FieldName = fmt.Sprintf("%v.%v", ctx.col.FieldName, col.FieldName)
col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...) col.FieldIndex = append(ctx.col.FieldIndex, col.FieldIndex...)
var tagPrefix = ctx.col.FieldName tagPrefix := ctx.col.FieldName
if len(ctx.params) > 0 { if len(ctx.params) > 0 {
col.Nullable = isPtr col.Nullable = isPtr
tagPrefix = strings.Trim(ctx.params[0], "'") tagPrefix = strings.Trim(ctx.params[0], "'")