From 81c947b61b7290945679a5a33100297d35aeb514 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Fri, 11 Apr 2014 21:06:11 +0800 Subject: [PATCH] flattened dialects dir and register db dialect for assocaited registered driver --- dialects/mssql.go => mssql_dialect.go | 100 +++++++++---------- dialects/mysql.go => mysql_dialect.go | 74 +++++++------- dialects/oracle.go => oracle_dialect.go | 78 +++++++-------- dialects/postgres.go => postgres_dialect.go | 104 ++++++++++---------- dialects/sqlite3.go => sqlite3_dialect.go | 82 +++++++-------- xorm.go | 23 ++++- 6 files changed, 241 insertions(+), 220 deletions(-) rename dialects/mssql.go => mssql_dialect.go (76%) rename dialects/mysql.go => mysql_dialect.go (81%) rename dialects/oracle.go => oracle_dialect.go (71%) rename dialects/postgres.go => postgres_dialect.go (71%) rename dialects/sqlite3.go => sqlite3_dialect.go (71%) diff --git a/dialects/mssql.go b/mssql_dialect.go similarity index 76% rename from dialects/mssql.go rename to mssql_dialect.go index 10ec3ad7..0b85f6a9 100644 --- a/dialects/mssql.go +++ b/mssql_dialect.go @@ -1,4 +1,4 @@ -package dialects +package xorm import ( "errors" @@ -6,58 +6,58 @@ import ( "strconv" "strings" - . "github.com/go-xorm/core" + "github.com/go-xorm/core" ) -func init() { - RegisterDialect("mssql", &mssql{}) -} +// func init() { +// RegisterDialect("mssql", &mssql{}) +// } type mssql struct { - Base + core.Base } -func (db *mssql) Init(uri *Uri, drivername, dataSourceName string) error { +func (db *mssql) Init(uri *core.Uri, drivername, dataSourceName string) error { return db.Base.Init(db, uri, drivername, dataSourceName) } -func (db *mssql) SqlType(c *Column) string { +func (db *mssql) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { - case Bool: - res = TinyInt - case Serial: + case core.Bool: + res = core.TinyInt + case core.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = Int - case BigSerial: + res = core.Int + case core.BigSerial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = BigInt - case Bytea, Blob, Binary, TinyBlob, MediumBlob, LongBlob: - res = VarBinary + res = core.BigInt + case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob: + res = core.VarBinary if c.Length == 0 { c.Length = 50 } - case TimeStamp: - res = DateTime - case TimeStampz: + case core.TimeStamp: + res = core.DateTime + case core.TimeStampz: res = "DATETIMEOFFSET" c.Length = 7 - case MediumInt: - res = Int - case MediumText, TinyText, LongText: - res = Text - case Double: - res = Real + case core.MediumInt: + res = core.Int + case core.MediumText, core.TinyText, core.LongText: + res = core.Text + case core.Double: + res = core.Real default: res = t } - if res == Int { - return Int + if res == core.Int { + return core.Int } var hasLen1 bool = (c.Length > 0) @@ -118,12 +118,12 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) { return sql, args } -func (db *mssql) GetColumns(tableName string) ([]string, map[string]*Column, error) { +func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{} s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id where a.object_id=object_id('` + tableName + `')` - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, nil, err } @@ -133,7 +133,7 @@ where a.object_id=object_id('` + tableName + `')` if err != nil { return nil, nil, err } - cols := make(map[string]*Column) + cols := make(map[string]*core.Column) colSeq := make([]string, 0) for rows.Next() { var name, ctype, precision, scale string @@ -143,7 +143,7 @@ where a.object_id=object_id('` + tableName + `')` return nil, nil, err } - col := new(Column) + col := new(core.Column) col.Indexes = make(map[string]bool) col.Length = maxLen col.Name = strings.Trim(name, "` ") @@ -151,14 +151,14 @@ where a.object_id=object_id('` + tableName + `')` ct := strings.ToUpper(ctype) switch ct { case "DATETIMEOFFSET": - col.SQLType = SQLType{TimeStampz, 0, 0} + col.SQLType = core.SQLType{core.TimeStampz, 0, 0} case "NVARCHAR": - col.SQLType = SQLType{Varchar, 0, 0} + col.SQLType = core.SQLType{core.Varchar, 0, 0} case "IMAGE": - col.SQLType = SQLType{VarBinary, 0, 0} + col.SQLType = core.SQLType{core.VarBinary, 0, 0} default: - if _, ok := SqlTypes[ct]; ok { - col.SQLType = SQLType{ct, 0, 0} + if _, ok := core.SqlTypes[ct]; ok { + col.SQLType = core.SQLType{ct, 0, 0} } else { return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v", ct, tableName, col.Name)) @@ -180,10 +180,10 @@ where a.object_id=object_id('` + tableName + `')` return colSeq, cols, nil } -func (db *mssql) GetTables() ([]*Table, error) { +func (db *mssql) GetTables() ([]*core.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -193,9 +193,9 @@ func (db *mssql) GetTables() ([]*Table, error) { return nil, err } - tables := make([]*Table, 0) + tables := make([]*core.Table, 0) for rows.Next() { - table := NewEmptyTable() + table := core.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -207,7 +207,7 @@ func (db *mssql) GetTables() ([]*Table, error) { return tables, nil } -func (db *mssql) GetIndexes(tableName string) (map[string]*Index, error) { +func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := `SELECT IXS.NAME AS [INDEX_NAME], @@ -223,7 +223,7 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -233,7 +233,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return nil, err } - indexes := make(map[string]*Index, 0) + indexes := make(map[string]*core.Index, 0) for rows.Next() { var indexType int var indexName, colName, isUnique string @@ -249,9 +249,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? } if i { - indexType = UniqueType + indexType = core.UniqueType } else { - indexType = IndexType + indexType = core.IndexType } colName = strings.Trim(colName, "` ") @@ -260,10 +260,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? indexName = indexName[5+len(tableName) : len(indexName)] } - var index *Index + var index *core.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(Index) + index = new(core.Index) index.Type = indexType index.Name = indexName indexes[indexName] = index @@ -273,7 +273,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? return indexes, nil } -func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset string) string { +func (db *mssql) CreateTablSql(table *core.Table, tableName, storeEngine, charset string) string { var sql string if tableName == "" { tableName = table.Name @@ -307,6 +307,6 @@ func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset str return sql } -func (db *mssql) Filters() []Filter { - return []Filter{&IdFilter{}, &QuoteFilter{}} +func (db *mssql) Filters() []core.Filter { + return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}} } diff --git a/dialects/mysql.go b/mysql_dialect.go similarity index 81% rename from dialects/mysql.go rename to mysql_dialect.go index ffba757c..52a8a15e 100644 --- a/dialects/mysql.go +++ b/mysql_dialect.go @@ -1,4 +1,4 @@ -package dialects +package xorm import ( "crypto/tls" @@ -8,15 +8,15 @@ import ( "strings" "time" - . "github.com/go-xorm/core" + "github.com/go-xorm/core" ) -func init() { - RegisterDialect("mysql", &mysql{}) -} +// func init() { +// RegisterDialect("mysql", &mysql{}) +// } type mysql struct { - Base + core.Base net string addr string params map[string]string @@ -28,30 +28,30 @@ type mysql struct { clientFoundRows bool } -func (db *mysql) Init(uri *Uri, drivername, dataSourceName string) error { +func (db *mysql) Init(uri *core.Uri, drivername, dataSourceName string) error { return db.Base.Init(db, uri, drivername, dataSourceName) } -func (db *mysql) SqlType(c *Column) string { +func (db *mysql) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { - case Bool: - res = TinyInt + case core.Bool: + res = core.TinyInt c.Length = 1 - case Serial: + case core.Serial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = Int - case BigSerial: + res = core.Int + case core.BigSerial: c.IsAutoIncrement = true c.IsPrimaryKey = true c.Nullable = false - res = BigInt - case Bytea: - res = Blob - case TimeStampz: - res = Char + res = core.BigInt + case core.Bytea: + res = core.Blob + case core.TimeStampz: + res = core.Char c.Length = 64 default: res = t @@ -110,11 +110,11 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { return sql, args } -func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { +func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{db.DbName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, nil, err } @@ -123,10 +123,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err if err != nil { return nil, nil, err } - cols := make(map[string]*Column) + cols := make(map[string]*core.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(Column) + col := new(core.Column) col.Indexes = make(map[string]bool) var columnName, isNullable, colType, colKey, extra string @@ -164,8 +164,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err colType = strings.ToUpper(colName) col.Length = len1 col.Length2 = len2 - if _, ok := SqlTypes[colType]; ok { - col.SQLType = SQLType{colType, len1, len2} + if _, ok := core.SqlTypes[colType]; ok { + col.SQLType = core.SQLType{colType, len1, len2} } else { return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) } @@ -192,10 +192,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err return colSeq, cols, nil } -func (db *mysql) GetTables() ([]*Table, error) { +func (db *mysql) GetTables() ([]*core.Table, error) { args := []interface{}{db.DbName} s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -205,9 +205,9 @@ func (db *mysql) GetTables() ([]*Table, error) { return nil, err } - tables := make([]*Table, 0) + tables := make([]*core.Table, 0) for rows.Next() { - table := NewEmptyTable() + table := core.NewEmptyTable() var name, engine, tableRows string var autoIncr *string err = rows.Scan(&name, &engine, &tableRows, &autoIncr) @@ -221,10 +221,10 @@ func (db *mysql) GetTables() ([]*Table, error) { return tables, nil } -func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { +func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{db.DbName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -234,7 +234,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { return nil, err } - indexes := make(map[string]*Index, 0) + indexes := make(map[string]*core.Index, 0) for rows.Next() { var indexType int var indexName, colName, nonUnique string @@ -248,9 +248,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { } if "YES" == nonUnique || nonUnique == "1" { - indexType = IndexType + indexType = core.IndexType } else { - indexType = UniqueType + indexType = core.UniqueType } colName = strings.Trim(colName, "` ") @@ -259,10 +259,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { indexName = indexName[5+len(tableName) : len(indexName)] } - var index *Index + var index *core.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(Index) + index = new(core.Index) index.Type = indexType index.Name = indexName indexes[indexName] = index @@ -272,6 +272,6 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { return indexes, nil } -func (db *mysql) Filters() []Filter { - return []Filter{&IdFilter{}} +func (db *mysql) Filters() []core.Filter { + return []core.Filter{&core.IdFilter{}} } diff --git a/dialects/oracle.go b/oracle_dialect.go similarity index 71% rename from dialects/oracle.go rename to oracle_dialect.go index b86a74fc..febd318e 100644 --- a/dialects/oracle.go +++ b/oracle_dialect.go @@ -1,4 +1,4 @@ -package dialects +package xorm import ( "errors" @@ -6,37 +6,37 @@ import ( "strconv" "strings" - . "github.com/go-xorm/core" + "github.com/go-xorm/core" ) -func init() { - RegisterDialect("oracle", &oracle{}) -} +// func init() { +// RegisterDialect("oracle", &oracle{}) +// } type oracle struct { - Base + core.Base } -func (db *oracle) Init(uri *Uri, drivername, dataSourceName string) error { +func (db *oracle) Init(uri *core.Uri, drivername, dataSourceName string) error { return db.Base.Init(db, uri, drivername, dataSourceName) } -func (db *oracle) SqlType(c *Column) string { +func (db *oracle) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial: + case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial: return "NUMBER" - case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea: - return Blob - case Time, DateTime, TimeStamp: - res = TimeStamp - case TimeStampz: + case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea: + return core.Blob + case core.Time, core.DateTime, core.TimeStamp: + res = core.TimeStamp + case core.TimeStampz: res = "TIMESTAMP WITH TIME ZONE" - case Float, Double, Numeric, Decimal: + case core.Float, core.Double, core.Numeric, core.Decimal: res = "NUMBER" - case Text, MediumText, LongText: + case core.Text, core.MediumText, core.LongText: res = "CLOB" - case Char, Varchar, TinyText: + case core.Char, core.Varchar, core.TinyText: return "VARCHAR2" default: res = t @@ -93,12 +93,12 @@ func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface " AND column_name = ?", args } -func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) { +func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{strings.ToUpper(tableName)} s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, nil, err } @@ -109,10 +109,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er } defer rows.Close() - cols := make(map[string]*Column) + cols := make(map[string]*core.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(Column) + col := new(core.Column) col.Indexes = make(map[string]bool) var colName, colDefault, nullable, dataType, dataPrecision, dataScale string @@ -135,13 +135,13 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er switch dataType { case "VARCHAR2": - col.SQLType = SQLType{Varchar, 0, 0} + col.SQLType = core.SQLType{core.Varchar, 0, 0} case "TIMESTAMP WITH TIME ZONE": - col.SQLType = SQLType{TimeStampz, 0, 0} + col.SQLType = core.SQLType{core.TimeStampz, 0, 0} default: - col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0} + col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0} } - if _, ok := SqlTypes[col.SQLType.Name]; !ok { + if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType)) } @@ -163,10 +163,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er return colSeq, cols, nil } -func (db *oracle) GetTables() ([]*Table, error) { +func (db *oracle) GetTables() ([]*core.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -176,9 +176,9 @@ func (db *oracle) GetTables() ([]*Table, error) { return nil, err } - tables := make([]*Table, 0) + tables := make([]*core.Table, 0) for rows.Next() { - table := NewEmptyTable() + table := core.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { return nil, err @@ -189,12 +189,12 @@ func (db *oracle) GetTables() ([]*Table, error) { return tables, nil } -func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { +func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -205,7 +205,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { } defer rows.Close() - indexes := make(map[string]*Index, 0) + indexes := make(map[string]*core.Index, 0) for rows.Next() { var indexType int var indexName, colName, uniqueness string @@ -218,15 +218,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { indexName = strings.Trim(indexName, `" `) if uniqueness == "UNIQUE" { - indexType = UniqueType + indexType = core.UniqueType } else { - indexType = IndexType + indexType = core.IndexType } - var index *Index + var index *core.Index var ok bool if index, ok = indexes[indexName]; !ok { - index = new(Index) + index = new(core.Index) index.Type = indexType index.Name = indexName indexes[indexName] = index @@ -240,7 +240,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { type OracleSeqFilter struct { } -func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string { +func (s *OracleSeqFilter) Do(sql string, dialect core.Dialect, table *core.Table) string { counts := strings.Count(sql, "?") for i := 1; i <= counts; i++ { newstr := ":" + fmt.Sprintf("%v", i) @@ -249,6 +249,6 @@ func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string { return sql } -func (db *oracle) Filters() []Filter { - return []Filter{&QuoteFilter{}, &OracleSeqFilter{}, &IdFilter{}} +func (db *oracle) Filters() []core.Filter { + return []core.Filter{&core.QuoteFilter{}, &OracleSeqFilter{}, &core.IdFilter{}} } diff --git a/dialects/postgres.go b/postgres_dialect.go similarity index 71% rename from dialects/postgres.go rename to postgres_dialect.go index d365732e..58d39f9d 100644 --- a/dialects/postgres.go +++ b/postgres_dialect.go @@ -1,4 +1,4 @@ -package dialects +package xorm import ( "errors" @@ -6,53 +6,53 @@ import ( "strconv" "strings" - . "github.com/go-xorm/core" + "github.com/go-xorm/core" ) -func init() { - RegisterDialect("postgres", &postgres{}) -} +// func init() { +// RegisterDialect("postgres", &postgres{}) +// } type postgres struct { - Base + core.Base } -func (db *postgres) Init(uri *Uri, drivername, dataSourceName string) error { +func (db *postgres) Init(uri *core.Uri, drivername, dataSourceName string) error { return db.Base.Init(db, uri, drivername, dataSourceName) } -func (db *postgres) SqlType(c *Column) string { +func (db *postgres) SqlType(c *core.Column) string { var res string switch t := c.SQLType.Name; t { - case TinyInt: - res = SmallInt + case core.TinyInt: + res = core.SmallInt return res - case MediumInt, Int, Integer: + case core.MediumInt, core.Int, core.Integer: if c.IsAutoIncrement { - return Serial + return core.Serial } - return Integer - case Serial, BigSerial: + return core.Integer + case core.Serial, core.BigSerial: c.IsAutoIncrement = true c.Nullable = false res = t - case Binary, VarBinary: - return Bytea - case DateTime: - res = TimeStamp - case TimeStampz: + case core.Binary, core.VarBinary: + return core.Bytea + case core.DateTime: + res = core.TimeStamp + case core.TimeStampz: return "timestamp with time zone" - case Float: - res = Real - case TinyText, MediumText, LongText: - res = Text - case Blob, TinyBlob, MediumBlob, LongBlob: - return Bytea - case Double: + case core.Float: + res = core.Real + case core.TinyText, core.MediumText, core.LongText: + res = core.Text + case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob: + return core.Bytea + case core.Double: return "DOUBLE PRECISION" default: if c.IsAutoIncrement { - return Serial + return core.Serial } res = t } @@ -108,11 +108,11 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa " AND column_name = ?", args } -func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { +func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, nil, err } @@ -121,11 +121,11 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, if err != nil { return nil, nil, err } - cols := make(map[string]*Column) + cols := make(map[string]*core.Column) colSeq := make([]string, 0) for rows.Next() { - col := new(Column) + col := new(core.Column) col.Indexes = make(map[string]bool) var colName, isNullable, dataType string @@ -161,21 +161,21 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, switch dataType { case "character varying", "character": - col.SQLType = SQLType{Varchar, 0, 0} + col.SQLType = core.SQLType{core.Varchar, 0, 0} case "timestamp without time zone": - col.SQLType = SQLType{DateTime, 0, 0} + col.SQLType = core.SQLType{core.DateTime, 0, 0} case "timestamp with time zone": - col.SQLType = SQLType{TimeStampz, 0, 0} + col.SQLType = core.SQLType{core.TimeStampz, 0, 0} case "double precision": - col.SQLType = SQLType{Double, 0, 0} + col.SQLType = core.SQLType{core.Double, 0, 0} case "boolean": - col.SQLType = SQLType{Bool, 0, 0} + col.SQLType = core.SQLType{core.Bool, 0, 0} case "time without time zone": - col.SQLType = SQLType{Time, 0, 0} + col.SQLType = core.SQLType{core.Time, 0, 0} default: - col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0} + col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0} } - if _, ok := SqlTypes[col.SQLType.Name]; !ok { + if _, ok := core.SqlTypes[col.SQLType.Name]; !ok { return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType)) } @@ -197,10 +197,10 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, return colSeq, cols, nil } -func (db *postgres) GetTables() ([]*Table, error) { +func (db *postgres) GetTables() ([]*core.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables where schemaname = 'public'" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -210,9 +210,9 @@ func (db *postgres) GetTables() ([]*Table, error) { return nil, err } - tables := make([]*Table, 0) + tables := make([]*core.Table, 0) for rows.Next() { - table := NewEmptyTable() + table := core.NewEmptyTable() var name string err = rows.Scan(&name) if err != nil { @@ -224,11 +224,11 @@ func (db *postgres) GetTables() ([]*Table, error) { return tables, nil } -func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { +func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := "SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -238,7 +238,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { return nil, err } - indexes := make(map[string]*Index, 0) + indexes := make(map[string]*core.Index, 0) for rows.Next() { var indexType int var indexName, indexdef string @@ -250,9 +250,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { indexName = strings.Trim(indexName, `" `) if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { - indexType = UniqueType + indexType = core.UniqueType } else { - indexType = IndexType + indexType = core.IndexType } cs := strings.Split(indexdef, "(") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") @@ -267,7 +267,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { } } - index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} for _, colName := range colNames { index.Cols = append(index.Cols, strings.Trim(colName, `" `)) } @@ -280,7 +280,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { type PgSeqFilter struct { } -func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string { +func (s *PgSeqFilter) Do(sql string, dialect core.Dialect, table *core.Table) string { segs := strings.Split(sql, "?") size := len(segs) res := "" @@ -293,6 +293,6 @@ func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string { return res } -func (db *postgres) Filters() []Filter { - return []Filter{&IdFilter{}, &QuoteFilter{}, &PgSeqFilter{}} +func (db *postgres) Filters() []core.Filter { + return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &PgSeqFilter{}} } diff --git a/dialects/sqlite3.go b/sqlite3_dialect.go similarity index 71% rename from dialects/sqlite3.go rename to sqlite3_dialect.go index aa829424..d5be5c72 100644 --- a/dialects/sqlite3.go +++ b/sqlite3_dialect.go @@ -1,44 +1,44 @@ -package dialects +package xorm import ( "strings" - . "github.com/go-xorm/core" + "github.com/go-xorm/core" ) -func init() { - RegisterDialect("sqlite3", &sqlite3{}) -} +// func init() { +// RegisterDialect("sqlite3", &sqlite3{}) +// } type sqlite3 struct { - Base + core.Base } -func (db *sqlite3) Init(uri *Uri, drivername, dataSourceName string) error { +func (db *sqlite3) Init(uri *core.Uri, drivername, dataSourceName string) error { return db.Base.Init(db, uri, drivername, dataSourceName) } -func (db *sqlite3) SqlType(c *Column) string { +func (db *sqlite3) SqlType(c *core.Column) string { switch t := c.SQLType.Name; t { - case Date, DateTime, TimeStamp, Time: - return Numeric - case TimeStampz: - return Text - case Char, Varchar, TinyText, Text, MediumText, LongText: - return Text - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: - return Integer - case Float, Double, Real: - return Real - case Decimal, Numeric: - return Numeric - case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: - return Blob - case Serial, BigSerial: + case core.Date, core.DateTime, core.TimeStamp, core.Time: + return core.Numeric + case core.TimeStampz: + return core.Text + case core.Char, core.Varchar, core.TinyText, core.Text, core.MediumText, core.LongText: + return core.Text + case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool: + return core.Integer + case core.Float, core.Double, core.Real: + return core.Real + case core.Decimal, core.Numeric: + return core.Numeric + case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary: + return core.Blob + case core.Serial, core.BigSerial: c.IsPrimaryKey = true c.IsAutoIncrement = true c.Nullable = false - return Integer + return core.Integer default: return t } @@ -84,10 +84,10 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac return sql, args } -func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { +func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, nil, err } @@ -110,11 +110,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e nStart := strings.Index(name, "(") nEnd := strings.Index(name, ")") colCreates := strings.Split(name[nStart+1:nEnd], ",") - cols := make(map[string]*Column) + cols := make(map[string]*core.Column) colSeq := make([]string, 0) for _, colStr := range colCreates { fields := strings.Fields(strings.TrimSpace(colStr)) - col := new(Column) + col := new(core.Column) col.Indexes = make(map[string]bool) col.Nullable = true for idx, field := range fields { @@ -122,7 +122,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e col.Name = strings.Trim(field, "`[] ") continue } else if idx == 1 { - col.SQLType = SQLType{field, 0, 0} + col.SQLType = core.SQLType{field, 0, 0} } switch field { case "PRIMARY": @@ -143,11 +143,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e return colSeq, cols, nil } -func (db *sqlite3) GetTables() ([]*Table, error) { +func (db *sqlite3) GetTables() ([]*core.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -158,9 +158,9 @@ func (db *sqlite3) GetTables() ([]*Table, error) { } defer rows.Close() - tables := make([]*Table, 0) + tables := make([]*core.Table, 0) for rows.Next() { - table := NewEmptyTable() + table := core.NewEmptyTable() err = rows.Scan(&table.Name) if err != nil { return nil, err @@ -173,10 +173,10 @@ func (db *sqlite3) GetTables() ([]*Table, error) { return tables, nil } -func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { +func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - cnn, err := Open(db.DriverName(), db.DataSourceName()) + cnn, err := core.Open(db.DriverName(), db.DataSourceName()) if err != nil { return nil, err } @@ -187,7 +187,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { } defer rows.Close() - indexes := make(map[string]*Index, 0) + indexes := make(map[string]*core.Index, 0) for rows.Next() { var sql string err = rows.Scan(&sql) @@ -199,7 +199,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { continue } - index := new(Index) + index := new(core.Index) nNStart := strings.Index(sql, "INDEX") nNEnd := strings.Index(sql, "ON") if nNStart == -1 || nNEnd == -1 { @@ -215,9 +215,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { } if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { - index.Type = UniqueType + index.Type = core.UniqueType } else { - index.Type = IndexType + index.Type = core.IndexType } nStart := strings.Index(sql, "(") @@ -234,6 +234,6 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { return indexes, nil } -func (db *sqlite3) Filters() []Filter { - return []Filter{&IdFilter{}} +func (db *sqlite3) Filters() []core.Filter { + return []core.Filter{&core.IdFilter{}} } diff --git a/xorm.go b/xorm.go index ccd33091..705b61a0 100644 --- a/xorm.go +++ b/xorm.go @@ -1,6 +1,7 @@ package xorm import ( + "database/sql" "errors" "fmt" "os" @@ -11,7 +12,6 @@ import ( "github.com/go-xorm/core" "github.com/go-xorm/xorm/caches" - _ "github.com/go-xorm/xorm/dialects" _ "github.com/go-xorm/xorm/drivers" ) @@ -19,6 +19,27 @@ const ( Version string = "0.4" ) +func init() { + provided_dialects := map[string]struct { + dbType core.DbType + get func() core.Dialect + }{ + "odbc": {"mssql", func() core.Dialect { return &mssql{} }}, + "mysql": {"mysql", func() core.Dialect { return &mysql{} }}, + "mymysql": {"mysql", func() core.Dialect { return &mysql{} }}, + "oci8": {"oracle", func() core.Dialect { return &oracle{} }}, + "postgres": {"postgres", func() core.Dialect { return &postgres{} }}, + "sqlite3": {"sqlite3", func() core.Dialect { return &sqlite3{} }}, + } + + for k, v := range provided_dialects { + _, err := sql.Open(string(k), "") + if err == nil { + core.RegisterDialect(v.dbType, v.get()) + } + } +} + func close(engine *Engine) { engine.Close() }