add column type unsigned

This commit is contained in:
Jerry 2020-10-10 01:18:03 +00:00 committed by Jerry
parent 0c1b815227
commit 50deb79581
6 changed files with 139 additions and 83 deletions

View File

@ -354,26 +354,32 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
cts := strings.Split(colType, "(") cts := strings.Split(colType, "(")
colName := cts[0] colName := cts[0]
colType = strings.ToUpper(colName) colType = strings.ToUpper(colName)
extra = strings.ToUpper(extra)
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
if colType == schemas.Enum && cts[1][0] == '\'' { // enum switch colType {
options := strings.Split(cts[1][0:idx], ",") case schemas.Enum:
col.EnumOptions = make(map[string]int) if cts[1][0] == '\'' {
for k, v := range options { options := strings.Split(cts[1][0:idx], ",")
v = strings.TrimSpace(v) col.EnumOptions = make(map[string]int)
v = strings.Trim(v, "'") for k, v := range options {
col.EnumOptions[v] = k v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
col.EnumOptions[v] = k
}
} }
} else if colType == schemas.Set && cts[1][0] == '\'' { case schemas.Set:
options := strings.Split(cts[1][0:idx], ",") if cts[1][0] == '\'' {
col.SetOptions = make(map[string]int) options := strings.Split(cts[1][0:idx], ",")
for k, v := range options { col.SetOptions = make(map[string]int)
v = strings.TrimSpace(v) for k, v := range options {
v = strings.Trim(v, "'") v = strings.TrimSpace(v)
col.SetOptions[v] = k v = strings.Trim(v, "'")
col.SetOptions[v] = k
}
} }
} else { default:
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
@ -387,19 +393,30 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
} }
} }
} }
if colType == "FLOAT UNSIGNED" {
colType = "FLOAT" switch colType {
} case "TINYINT UNSIGNED":
if colType == "DOUBLE UNSIGNED" { colType = schemas.TinyIntUnsigned
colType = "DOUBLE" case "SMALLINT UNSIGNED":
colType = schemas.SmallIntUnsigned
case "MEDIUMINT UNSIGNED":
colType = schemas.MediumIntUnsigned
case "INT UNSIGNED":
colType = schemas.IntUnsigned
case "BIGINT UNSIGNED":
colType = schemas.BigIntUnsigned
case "FLOAT UNSIGNED":
colType = schemas.FloatUnsigned
case "DOUBLE UNSIGNED":
colType = schemas.DoubleUnsigned
} }
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := schemas.SqlTypes[colType]; ok { if _, ok := schemas.SqlTypes[colType]; !ok {
col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2}
} else {
return nil, nil, fmt.Errorf("Unknown colType %v", colType) return nil, nil, fmt.Errorf("Unknown colType %v", colType)
} }
col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2}
if colKey == "PRI" { if colKey == "PRI" {
col.IsPrimaryKey = true col.IsPrimaryKey = true
@ -408,7 +425,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName
// col.is // col.is
} }
if extra == "auto_increment" { if extra == "AUTO_INCREMENT" {
col.IsAutoIncrement = true col.IsAutoIncrement = true
} }

View File

@ -61,6 +61,10 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
return nil, err return nil, err
} }
return newEngine(driverName, dataSourceName, dialect, db)
}
func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db *core.DB) (*Engine, error) {
cacherMgr := caches.NewManager() cacherMgr := caches.NewManager()
mapper := names.NewCacheMapper(new(names.SnakeMapper)) mapper := names.NewCacheMapper(new(names.SnakeMapper))
tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr)
@ -88,7 +92,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
engine.SetLogger(log.NewLoggerAdapter(logger)) engine.SetLogger(log.NewLoggerAdapter(logger))
runtime.SetFinalizer(engine, func(engine *Engine) { runtime.SetFinalizer(engine, func(engine *Engine) {
engine.Close() _ = engine.Close()
}) })
return engine, nil return engine, nil
@ -101,6 +105,14 @@ func NewEngineWithParams(driverName string, dataSourceName string, params map[st
return engine, err return engine, err
} }
// NewEngineWithDialectAndDB new a db manager according to the parameter.
// If you do not want to use your own dialect or db, please use NewEngine.
// For creating dialect, you can call dialects.OpenDialect. And, for creating db,
// you can call core.Open or core.FromDB.
func NewEngineWithDialectAndDB(driverName, dataSourceName string, dialect dialects.Dialect, db *core.DB) (*Engine, error) {
return newEngine(driverName, dataSourceName, dialect, db)
}
// EnableSessionID if enable session id // EnableSessionID if enable session id
func (engine *Engine) EnableSessionID(enable bool) { func (engine *Engine) EnableSessionID(enable bool) {
engine.logSessionID = enable engine.logSessionID = enable

View File

@ -69,13 +69,18 @@ func (s *SQLType) IsJson() bool {
} }
var ( var (
Bit = "BIT" Bit = "BIT"
TinyInt = "TINYINT" TinyInt = "TINYINT"
SmallInt = "SMALLINT" TinyIntUnsigned = "TINYINT_UNSIGNED"
MediumInt = "MEDIUMINT" SmallInt = "SMALLINT"
Int = "INT" SmallIntUnsigned = "SMALLINT_UNSIGNED"
Integer = "INTEGER" MediumInt = "MEDIUMINT"
BigInt = "BIGINT" MediumIntUnsigned = "MEDIUMINT_UNSIGNED"
Int = "INT"
IntUnsigned = "INT_UNSIGNED"
Integer = "INTEGER"
BigInt = "BIGINT"
BigIntUnsigned = "BIGINT_UNSIGNED"
Enum = "ENUM" Enum = "ENUM"
Set = "SET" Set = "SET"
@ -107,9 +112,11 @@ var (
Money = "MONEY" Money = "MONEY"
SmallMoney = "SMALLMONEY" SmallMoney = "SMALLMONEY"
Real = "REAL" Real = "REAL"
Float = "FLOAT" Float = "FLOAT"
Double = "DOUBLE" FloatUnsigned = "FLOAT_UNSIGNED"
Double = "DOUBLE"
DoubleUnsigned = "DOUBLE_UNSIGNED"
Binary = "BINARY" Binary = "BINARY"
VarBinary = "VARBINARY" VarBinary = "VARBINARY"
@ -131,13 +138,18 @@ var (
Array = "ARRAY" Array = "ARRAY"
SqlTypes = map[string]int{ SqlTypes = map[string]int{
Bit: NUMERIC_TYPE, Bit: NUMERIC_TYPE,
TinyInt: NUMERIC_TYPE, TinyInt: NUMERIC_TYPE,
SmallInt: NUMERIC_TYPE, TinyIntUnsigned: NUMERIC_TYPE,
MediumInt: NUMERIC_TYPE, SmallInt: NUMERIC_TYPE,
Int: NUMERIC_TYPE, SmallIntUnsigned: NUMERIC_TYPE,
Integer: NUMERIC_TYPE, MediumInt: NUMERIC_TYPE,
BigInt: NUMERIC_TYPE, MediumIntUnsigned: NUMERIC_TYPE,
Int: NUMERIC_TYPE,
IntUnsigned: NUMERIC_TYPE,
Integer: NUMERIC_TYPE,
BigInt: NUMERIC_TYPE,
BigIntUnsigned: NUMERIC_TYPE,
Enum: TEXT_TYPE, Enum: TEXT_TYPE,
Set: TEXT_TYPE, Set: TEXT_TYPE,
@ -165,13 +177,15 @@ var (
SmallDateTime: TIME_TYPE, SmallDateTime: TIME_TYPE,
Year: TIME_TYPE, Year: TIME_TYPE,
Decimal: NUMERIC_TYPE, Decimal: NUMERIC_TYPE,
Numeric: NUMERIC_TYPE, Numeric: NUMERIC_TYPE,
Real: NUMERIC_TYPE, Real: NUMERIC_TYPE,
Float: NUMERIC_TYPE, Float: NUMERIC_TYPE,
Double: NUMERIC_TYPE, FloatUnsigned: NUMERIC_TYPE,
Money: NUMERIC_TYPE, Double: NUMERIC_TYPE,
SmallMoney: NUMERIC_TYPE, DoubleUnsigned: NUMERIC_TYPE,
Money: NUMERIC_TYPE,
SmallMoney: NUMERIC_TYPE,
Binary: BLOB_TYPE, Binary: BLOB_TYPE,
VarBinary: BLOB_TYPE, VarBinary: BLOB_TYPE,

View File

@ -250,11 +250,9 @@ func (session *Session) Sync2(beans ...interface{}) error {
if err != nil { if err != nil {
return err return err
} }
var tbName string tbName := engine.TableName(bean)
if len(session.statement.AltTableName) > 0 { if len(session.statement.AltTableName) > 0 {
tbName = session.statement.AltTableName tbName = session.statement.AltTableName
} else {
tbName = engine.TableName(bean)
} }
tbNameWithSchema := engine.tbNameWithSchema(tbName) tbNameWithSchema := engine.tbNameWithSchema(tbName)

View File

@ -205,8 +205,6 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) {
} }
if j < len(tags)-1 { if j < len(tags)-1 {
ctx.nextTag = tags[j+1] ctx.nextTag = tags[j+1]
} else {
ctx.nextTag = ""
} }
if h, ok := parser.handlers[ctx.tagName]; ok { if h, ok := parser.handlers[ctx.tagName]; ok {

View File

@ -225,37 +225,54 @@ func CommentTagHandler(ctx *Context) error {
// SQLTypeTagHandler describes SQL Type tag handler // SQLTypeTagHandler describes SQL Type tag handler
func SQLTypeTagHandler(ctx *Context) error { func SQLTypeTagHandler(ctx *Context) error {
ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} switch ctx.tagName {
if len(ctx.params) > 0 { case schemas.TinyIntUnsigned:
if ctx.tagName == schemas.Enum { ctx.col.SQLType = schemas.SQLType{Name: "TINYINT UNSIGNED"}
ctx.col.EnumOptions = make(map[string]int) case schemas.SmallIntUnsigned:
for k, v := range ctx.params { ctx.col.SQLType = schemas.SQLType{Name: "SMALLINT UNSIGNED"}
v = strings.TrimSpace(v) case schemas.MediumIntUnsigned:
v = strings.Trim(v, "'") ctx.col.SQLType = schemas.SQLType{Name: "MEDIUMINT UNSIGNED"}
ctx.col.EnumOptions[v] = k case schemas.IntUnsigned:
} ctx.col.SQLType = schemas.SQLType{Name: "INT UNSIGNED"}
} else if ctx.tagName == schemas.Set { case schemas.BigIntUnsigned:
ctx.col.SetOptions = make(map[string]int) ctx.col.SQLType = schemas.SQLType{Name: "BIGINT UNSIGNED"}
for k, v := range ctx.params { case schemas.FloatUnsigned:
v = strings.TrimSpace(v) ctx.col.SQLType = schemas.SQLType{Name: "FLOAT UNSIGNED"}
v = strings.Trim(v, "'") case schemas.DoubleUnsigned:
ctx.col.SetOptions[v] = k ctx.col.SQLType = schemas.SQLType{Name: "DOUBLE UNSIGNED"}
} default:
} else { ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName}
var err error if len(ctx.params) > 0 {
if len(ctx.params) == 2 { if ctx.tagName == schemas.Enum {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) ctx.col.EnumOptions = make(map[string]int)
if err != nil { for k, v := range ctx.params {
return err v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.EnumOptions[v] = k
} }
ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) } else if ctx.tagName == schemas.Set {
if err != nil { ctx.col.SetOptions = make(map[string]int)
return err for k, v := range ctx.params {
v = strings.TrimSpace(v)
v = strings.Trim(v, "'")
ctx.col.SetOptions[v] = k
} }
} else if len(ctx.params) == 1 { } else {
ctx.col.Length, err = strconv.Atoi(ctx.params[0]) var err error
if err != nil { if len(ctx.params) == 2 {
return err ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
ctx.col.Length2, err = strconv.Atoi(ctx.params[1])
if err != nil {
return err
}
} else if len(ctx.params) == 1 {
ctx.col.Length, err = strconv.Atoi(ctx.params[0])
if err != nil {
return err
}
} }
} }
} }