From 50deb795816c629d7e7ecdb52a4fcd9eb15a1ec1 Mon Sep 17 00:00:00 2001 From: Jerry Date: Sat, 10 Oct 2020 01:18:03 +0000 Subject: [PATCH] add column type unsigned --- dialects/mysql.go | 65 +++++++++++++++++++++++++--------------- engine.go | 14 ++++++++- schemas/type.go | 62 ++++++++++++++++++++++++--------------- session_schema.go | 4 +-- tags/parser.go | 2 -- tags/tag.go | 75 +++++++++++++++++++++++++++++------------------ 6 files changed, 139 insertions(+), 83 deletions(-) diff --git a/dialects/mysql.go b/dialects/mysql.go index 32e18a17..2a56fb81 100644 --- a/dialects/mysql.go +++ b/dialects/mysql.go @@ -354,26 +354,32 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName cts := strings.Split(colType, "(") colName := cts[0] colType = strings.ToUpper(colName) + extra = strings.ToUpper(extra) var len1, len2 int if len(cts) == 2 { idx := strings.Index(cts[1], ")") - if colType == schemas.Enum && cts[1][0] == '\'' { // enum - options := strings.Split(cts[1][0:idx], ",") - col.EnumOptions = make(map[string]int) - for k, v := range options { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - col.EnumOptions[v] = k + switch colType { + case schemas.Enum: + if cts[1][0] == '\'' { + options := strings.Split(cts[1][0:idx], ",") + col.EnumOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.EnumOptions[v] = k + } } - } else if colType == schemas.Set && cts[1][0] == '\'' { - options := strings.Split(cts[1][0:idx], ",") - col.SetOptions = make(map[string]int) - for k, v := range options { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - col.SetOptions[v] = k + case schemas.Set: + if cts[1][0] == '\'' { + options := strings.Split(cts[1][0:idx], ",") + col.SetOptions = make(map[string]int) + for k, v := range options { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + col.SetOptions[v] = k + } } - } else { + default: lens := strings.Split(cts[1][0:idx], ",") len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) if err != nil { @@ -387,19 +393,30 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName } } } - if colType == "FLOAT UNSIGNED" { - colType = "FLOAT" - } - if colType == "DOUBLE UNSIGNED" { - colType = "DOUBLE" + + switch colType { + case "TINYINT UNSIGNED": + colType = schemas.TinyIntUnsigned + case "SMALLINT UNSIGNED": + colType = schemas.SmallIntUnsigned + case "MEDIUMINT UNSIGNED": + colType = schemas.MediumIntUnsigned + case "INT UNSIGNED": + colType = schemas.IntUnsigned + case "BIGINT UNSIGNED": + colType = schemas.BigIntUnsigned + case "FLOAT UNSIGNED": + colType = schemas.FloatUnsigned + case "DOUBLE UNSIGNED": + colType = schemas.DoubleUnsigned } + col.Length = len1 col.Length2 = len2 - if _, ok := schemas.SqlTypes[colType]; ok { - col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} - } else { + if _, ok := schemas.SqlTypes[colType]; !ok { return nil, nil, fmt.Errorf("Unknown colType %v", colType) } + col.SQLType = schemas.SQLType{Name: colType, DefaultLength: len1, DefaultLength2: len2} if colKey == "PRI" { col.IsPrimaryKey = true @@ -408,7 +425,7 @@ func (db *mysql) GetColumns(queryer core.Queryer, ctx context.Context, tableName // col.is } - if extra == "auto_increment" { + if extra == "AUTO_INCREMENT" { col.IsAutoIncrement = true } diff --git a/engine.go b/engine.go index 4159a7b2..6c894e74 100644 --- a/engine.go +++ b/engine.go @@ -61,6 +61,10 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, err } + return newEngine(driverName, dataSourceName, dialect, db) +} + +func newEngine(driverName, dataSourceName string, dialect dialects.Dialect, db *core.DB) (*Engine, error) { cacherMgr := caches.NewManager() mapper := names.NewCacheMapper(new(names.SnakeMapper)) tagParser := tags.NewParser("xorm", dialect, mapper, mapper, cacherMgr) @@ -88,7 +92,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { engine.SetLogger(log.NewLoggerAdapter(logger)) runtime.SetFinalizer(engine, func(engine *Engine) { - engine.Close() + _ = engine.Close() }) return engine, nil @@ -101,6 +105,14 @@ func NewEngineWithParams(driverName string, dataSourceName string, params map[st return engine, err } +// NewEngineWithDialectAndDB new a db manager according to the parameter. +// If you do not want to use your own dialect or db, please use NewEngine. +// For creating dialect, you can call dialects.OpenDialect. And, for creating db, +// you can call core.Open or core.FromDB. +func NewEngineWithDialectAndDB(driverName, dataSourceName string, dialect dialects.Dialect, db *core.DB) (*Engine, error) { + return newEngine(driverName, dataSourceName, dialect, db) +} + // EnableSessionID if enable session id func (engine *Engine) EnableSessionID(enable bool) { engine.logSessionID = enable diff --git a/schemas/type.go b/schemas/type.go index 89459a4d..3bd752c9 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -69,13 +69,18 @@ func (s *SQLType) IsJson() bool { } var ( - Bit = "BIT" - TinyInt = "TINYINT" - SmallInt = "SMALLINT" - MediumInt = "MEDIUMINT" - Int = "INT" - Integer = "INTEGER" - BigInt = "BIGINT" + Bit = "BIT" + TinyInt = "TINYINT" + TinyIntUnsigned = "TINYINT_UNSIGNED" + SmallInt = "SMALLINT" + SmallIntUnsigned = "SMALLINT_UNSIGNED" + MediumInt = "MEDIUMINT" + MediumIntUnsigned = "MEDIUMINT_UNSIGNED" + Int = "INT" + IntUnsigned = "INT_UNSIGNED" + Integer = "INTEGER" + BigInt = "BIGINT" + BigIntUnsigned = "BIGINT_UNSIGNED" Enum = "ENUM" Set = "SET" @@ -107,9 +112,11 @@ var ( Money = "MONEY" SmallMoney = "SMALLMONEY" - Real = "REAL" - Float = "FLOAT" - Double = "DOUBLE" + Real = "REAL" + Float = "FLOAT" + FloatUnsigned = "FLOAT_UNSIGNED" + Double = "DOUBLE" + DoubleUnsigned = "DOUBLE_UNSIGNED" Binary = "BINARY" VarBinary = "VARBINARY" @@ -131,13 +138,18 @@ var ( Array = "ARRAY" SqlTypes = map[string]int{ - Bit: NUMERIC_TYPE, - TinyInt: NUMERIC_TYPE, - SmallInt: NUMERIC_TYPE, - MediumInt: NUMERIC_TYPE, - Int: NUMERIC_TYPE, - Integer: NUMERIC_TYPE, - BigInt: NUMERIC_TYPE, + Bit: NUMERIC_TYPE, + TinyInt: NUMERIC_TYPE, + TinyIntUnsigned: NUMERIC_TYPE, + SmallInt: NUMERIC_TYPE, + SmallIntUnsigned: NUMERIC_TYPE, + MediumInt: NUMERIC_TYPE, + MediumIntUnsigned: NUMERIC_TYPE, + Int: NUMERIC_TYPE, + IntUnsigned: NUMERIC_TYPE, + Integer: NUMERIC_TYPE, + BigInt: NUMERIC_TYPE, + BigIntUnsigned: NUMERIC_TYPE, Enum: TEXT_TYPE, Set: TEXT_TYPE, @@ -165,13 +177,15 @@ var ( SmallDateTime: TIME_TYPE, Year: TIME_TYPE, - Decimal: NUMERIC_TYPE, - Numeric: NUMERIC_TYPE, - Real: NUMERIC_TYPE, - Float: NUMERIC_TYPE, - Double: NUMERIC_TYPE, - Money: NUMERIC_TYPE, - SmallMoney: NUMERIC_TYPE, + Decimal: NUMERIC_TYPE, + Numeric: NUMERIC_TYPE, + Real: NUMERIC_TYPE, + Float: NUMERIC_TYPE, + FloatUnsigned: NUMERIC_TYPE, + Double: NUMERIC_TYPE, + DoubleUnsigned: NUMERIC_TYPE, + Money: NUMERIC_TYPE, + SmallMoney: NUMERIC_TYPE, Binary: BLOB_TYPE, VarBinary: BLOB_TYPE, diff --git a/session_schema.go b/session_schema.go index 9ccf8abe..855ab736 100644 --- a/session_schema.go +++ b/session_schema.go @@ -250,11 +250,9 @@ func (session *Session) Sync2(beans ...interface{}) error { if err != nil { return err } - var tbName string + tbName := engine.TableName(bean) if len(session.statement.AltTableName) > 0 { tbName = session.statement.AltTableName - } else { - tbName = engine.TableName(bean) } tbNameWithSchema := engine.tbNameWithSchema(tbName) diff --git a/tags/parser.go b/tags/parser.go index add30a13..71b2336c 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -205,8 +205,6 @@ func (parser *Parser) Parse(v reflect.Value) (*schemas.Table, error) { } if j < len(tags)-1 { ctx.nextTag = tags[j+1] - } else { - ctx.nextTag = "" } if h, ok := parser.handlers[ctx.tagName]; ok { diff --git a/tags/tag.go b/tags/tag.go index ee3f1e82..db27e4c0 100644 --- a/tags/tag.go +++ b/tags/tag.go @@ -225,37 +225,54 @@ func CommentTagHandler(ctx *Context) error { // SQLTypeTagHandler describes SQL Type tag handler func SQLTypeTagHandler(ctx *Context) error { - ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} - if len(ctx.params) > 0 { - if ctx.tagName == schemas.Enum { - ctx.col.EnumOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.EnumOptions[v] = k - } - } else if ctx.tagName == schemas.Set { - ctx.col.SetOptions = make(map[string]int) - for k, v := range ctx.params { - v = strings.TrimSpace(v) - v = strings.Trim(v, "'") - ctx.col.SetOptions[v] = k - } - } else { - var err error - if len(ctx.params) == 2 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err + switch ctx.tagName { + case schemas.TinyIntUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "TINYINT UNSIGNED"} + case schemas.SmallIntUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "SMALLINT UNSIGNED"} + case schemas.MediumIntUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "MEDIUMINT UNSIGNED"} + case schemas.IntUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "INT UNSIGNED"} + case schemas.BigIntUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "BIGINT UNSIGNED"} + case schemas.FloatUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "FLOAT UNSIGNED"} + case schemas.DoubleUnsigned: + ctx.col.SQLType = schemas.SQLType{Name: "DOUBLE UNSIGNED"} + default: + ctx.col.SQLType = schemas.SQLType{Name: ctx.tagName} + if len(ctx.params) > 0 { + if ctx.tagName == schemas.Enum { + ctx.col.EnumOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.EnumOptions[v] = k } - ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) - if err != nil { - return err + } else if ctx.tagName == schemas.Set { + ctx.col.SetOptions = make(map[string]int) + for k, v := range ctx.params { + v = strings.TrimSpace(v) + v = strings.Trim(v, "'") + ctx.col.SetOptions[v] = k } - } else if len(ctx.params) == 1 { - ctx.col.Length, err = strconv.Atoi(ctx.params[0]) - if err != nil { - return err + } else { + var err error + if len(ctx.params) == 2 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err + } + ctx.col.Length2, err = strconv.Atoi(ctx.params[1]) + if err != nil { + return err + } + } else if len(ctx.params) == 1 { + ctx.col.Length, err = strconv.Atoi(ctx.params[0]) + if err != nil { + return err + } } } }