flattened dialects dir and register db dialect for assocaited registered driver

This commit is contained in:
Nash Tsai 2014-04-11 21:06:11 +08:00
parent a0919b5371
commit 81c947b61b
6 changed files with 241 additions and 220 deletions

View File

@ -1,4 +1,4 @@
package dialects
package xorm
import (
"errors"
@ -6,58 +6,58 @@ import (
"strconv"
"strings"
. "github.com/go-xorm/core"
"github.com/go-xorm/core"
)
func init() {
RegisterDialect("mssql", &mssql{})
}
// func init() {
// RegisterDialect("mssql", &mssql{})
// }
type mssql struct {
Base
core.Base
}
func (db *mssql) Init(uri *Uri, drivername, dataSourceName string) error {
func (db *mssql) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
}
func (db *mssql) SqlType(c *Column) string {
func (db *mssql) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
case Bool:
res = TinyInt
case Serial:
case core.Bool:
res = core.TinyInt
case core.Serial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = Int
case BigSerial:
res = core.Int
case core.BigSerial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = BigInt
case Bytea, Blob, Binary, TinyBlob, MediumBlob, LongBlob:
res = VarBinary
res = core.BigInt
case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob:
res = core.VarBinary
if c.Length == 0 {
c.Length = 50
}
case TimeStamp:
res = DateTime
case TimeStampz:
case core.TimeStamp:
res = core.DateTime
case core.TimeStampz:
res = "DATETIMEOFFSET"
c.Length = 7
case MediumInt:
res = Int
case MediumText, TinyText, LongText:
res = Text
case Double:
res = Real
case core.MediumInt:
res = core.Int
case core.MediumText, core.TinyText, core.LongText:
res = core.Text
case core.Double:
res = core.Real
default:
res = t
}
if res == Int {
return Int
if res == core.Int {
return core.Int
}
var hasLen1 bool = (c.Length > 0)
@ -118,12 +118,12 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) {
return sql, args
}
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{}
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale
from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id
where a.object_id=object_id('` + tableName + `')`
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, nil, err
}
@ -133,7 +133,7 @@ where a.object_id=object_id('` + tableName + `')`
if err != nil {
return nil, nil, err
}
cols := make(map[string]*Column)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)
for rows.Next() {
var name, ctype, precision, scale string
@ -143,7 +143,7 @@ where a.object_id=object_id('` + tableName + `')`
return nil, nil, err
}
col := new(Column)
col := new(core.Column)
col.Indexes = make(map[string]bool)
col.Length = maxLen
col.Name = strings.Trim(name, "` ")
@ -151,14 +151,14 @@ where a.object_id=object_id('` + tableName + `')`
ct := strings.ToUpper(ctype)
switch ct {
case "DATETIMEOFFSET":
col.SQLType = SQLType{TimeStampz, 0, 0}
col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
case "NVARCHAR":
col.SQLType = SQLType{Varchar, 0, 0}
col.SQLType = core.SQLType{core.Varchar, 0, 0}
case "IMAGE":
col.SQLType = SQLType{VarBinary, 0, 0}
col.SQLType = core.SQLType{core.VarBinary, 0, 0}
default:
if _, ok := SqlTypes[ct]; ok {
col.SQLType = SQLType{ct, 0, 0}
if _, ok := core.SqlTypes[ct]; ok {
col.SQLType = core.SQLType{ct, 0, 0}
} else {
return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v",
ct, tableName, col.Name))
@ -180,10 +180,10 @@ where a.object_id=object_id('` + tableName + `')`
return colSeq, cols, nil
}
func (db *mssql) GetTables() ([]*Table, error) {
func (db *mssql) GetTables() ([]*core.Table, error) {
args := []interface{}{}
s := `select name from sysobjects where xtype ='U'`
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -193,9 +193,9 @@ func (db *mssql) GetTables() ([]*Table, error) {
return nil, err
}
tables := make([]*Table, 0)
tables := make([]*core.Table, 0)
for rows.Next() {
table := NewEmptyTable()
table := core.NewEmptyTable()
var name string
err = rows.Scan(&name)
if err != nil {
@ -207,7 +207,7 @@ func (db *mssql) GetTables() ([]*Table, error) {
return tables, nil
}
func (db *mssql) GetIndexes(tableName string) (map[string]*Index, error) {
func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName}
s := `SELECT
IXS.NAME AS [INDEX_NAME],
@ -223,7 +223,7 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID
AND IXCS.COLUMN_ID=C.COLUMN_ID
WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
`
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -233,7 +233,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return nil, err
}
indexes := make(map[string]*Index, 0)
indexes := make(map[string]*core.Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, isUnique string
@ -249,9 +249,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
}
if i {
indexType = UniqueType
indexType = core.UniqueType
} else {
indexType = IndexType
indexType = core.IndexType
}
colName = strings.Trim(colName, "` ")
@ -260,10 +260,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
indexName = indexName[5+len(tableName) : len(indexName)]
}
var index *Index
var index *core.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(Index)
index = new(core.Index)
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
@ -273,7 +273,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil
}
func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset string) string {
func (db *mssql) CreateTablSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string
if tableName == "" {
tableName = table.Name
@ -307,6 +307,6 @@ func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset str
return sql
}
func (db *mssql) Filters() []Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}}
func (db *mssql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}}
}

View File

@ -1,4 +1,4 @@
package dialects
package xorm
import (
"crypto/tls"
@ -8,15 +8,15 @@ import (
"strings"
"time"
. "github.com/go-xorm/core"
"github.com/go-xorm/core"
)
func init() {
RegisterDialect("mysql", &mysql{})
}
// func init() {
// RegisterDialect("mysql", &mysql{})
// }
type mysql struct {
Base
core.Base
net string
addr string
params map[string]string
@ -28,30 +28,30 @@ type mysql struct {
clientFoundRows bool
}
func (db *mysql) Init(uri *Uri, drivername, dataSourceName string) error {
func (db *mysql) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
}
func (db *mysql) SqlType(c *Column) string {
func (db *mysql) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
case Bool:
res = TinyInt
case core.Bool:
res = core.TinyInt
c.Length = 1
case Serial:
case core.Serial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = Int
case BigSerial:
res = core.Int
case core.BigSerial:
c.IsAutoIncrement = true
c.IsPrimaryKey = true
c.Nullable = false
res = BigInt
case Bytea:
res = Blob
case TimeStampz:
res = Char
res = core.BigInt
case core.Bytea:
res = core.Blob
case core.TimeStampz:
res = core.Char
c.Length = 64
default:
res = t
@ -110,11 +110,11 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
return sql, args
}
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{db.DbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, nil, err
}
@ -123,10 +123,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
if err != nil {
return nil, nil, err
}
cols := make(map[string]*Column)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(Column)
col := new(core.Column)
col.Indexes = make(map[string]bool)
var columnName, isNullable, colType, colKey, extra string
@ -164,8 +164,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
colType = strings.ToUpper(colName)
col.Length = len1
col.Length2 = len2
if _, ok := SqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2}
if _, ok := core.SqlTypes[colType]; ok {
col.SQLType = core.SQLType{colType, len1, len2}
} else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
}
@ -192,10 +192,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
return colSeq, cols, nil
}
func (db *mysql) GetTables() ([]*Table, error) {
func (db *mysql) GetTables() ([]*core.Table, error) {
args := []interface{}{db.DbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -205,9 +205,9 @@ func (db *mysql) GetTables() ([]*Table, error) {
return nil, err
}
tables := make([]*Table, 0)
tables := make([]*core.Table, 0)
for rows.Next() {
table := NewEmptyTable()
table := core.NewEmptyTable()
var name, engine, tableRows string
var autoIncr *string
err = rows.Scan(&name, &engine, &tableRows, &autoIncr)
@ -221,10 +221,10 @@ func (db *mysql) GetTables() ([]*Table, error) {
return tables, nil
}
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{db.DbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -234,7 +234,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, err
}
indexes := make(map[string]*Index, 0)
indexes := make(map[string]*core.Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, nonUnique string
@ -248,9 +248,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
}
if "YES" == nonUnique || nonUnique == "1" {
indexType = IndexType
indexType = core.IndexType
} else {
indexType = UniqueType
indexType = core.UniqueType
}
colName = strings.Trim(colName, "` ")
@ -259,10 +259,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
indexName = indexName[5+len(tableName) : len(indexName)]
}
var index *Index
var index *core.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(Index)
index = new(core.Index)
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
@ -272,6 +272,6 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
return indexes, nil
}
func (db *mysql) Filters() []Filter {
return []Filter{&IdFilter{}}
func (db *mysql) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}}
}

View File

@ -1,4 +1,4 @@
package dialects
package xorm
import (
"errors"
@ -6,37 +6,37 @@ import (
"strconv"
"strings"
. "github.com/go-xorm/core"
"github.com/go-xorm/core"
)
func init() {
RegisterDialect("oracle", &oracle{})
}
// func init() {
// RegisterDialect("oracle", &oracle{})
// }
type oracle struct {
Base
core.Base
}
func (db *oracle) Init(uri *Uri, drivername, dataSourceName string) error {
func (db *oracle) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
}
func (db *oracle) SqlType(c *Column) string {
func (db *oracle) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial:
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial:
return "NUMBER"
case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea:
return Blob
case Time, DateTime, TimeStamp:
res = TimeStamp
case TimeStampz:
case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea:
return core.Blob
case core.Time, core.DateTime, core.TimeStamp:
res = core.TimeStamp
case core.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE"
case Float, Double, Numeric, Decimal:
case core.Float, core.Double, core.Numeric, core.Decimal:
res = "NUMBER"
case Text, MediumText, LongText:
case core.Text, core.MediumText, core.LongText:
res = "CLOB"
case Char, Varchar, TinyText:
case core.Char, core.Varchar, core.TinyText:
return "VARCHAR2"
default:
res = t
@ -93,12 +93,12 @@ func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface
" AND column_name = ?", args
}
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) {
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{strings.ToUpper(tableName)}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, nil, err
}
@ -109,10 +109,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
}
defer rows.Close()
cols := make(map[string]*Column)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(Column)
col := new(core.Column)
col.Indexes = make(map[string]bool)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale string
@ -135,13 +135,13 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
switch dataType {
case "VARCHAR2":
col.SQLType = SQLType{Varchar, 0, 0}
col.SQLType = core.SQLType{core.Varchar, 0, 0}
case "TIMESTAMP WITH TIME ZONE":
col.SQLType = SQLType{TimeStampz, 0, 0}
col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
default:
col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0}
col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
}
if _, ok := SqlTypes[col.SQLType.Name]; !ok {
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
}
@ -163,10 +163,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
return colSeq, cols, nil
}
func (db *oracle) GetTables() ([]*Table, error) {
func (db *oracle) GetTables() ([]*core.Table, error) {
args := []interface{}{}
s := "SELECT table_name FROM user_tables"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -176,9 +176,9 @@ func (db *oracle) GetTables() ([]*Table, error) {
return nil, err
}
tables := make([]*Table, 0)
tables := make([]*core.Table, 0)
for rows.Next() {
table := NewEmptyTable()
table := core.NewEmptyTable()
err = rows.Scan(&table.Name)
if err != nil {
return nil, err
@ -189,12 +189,12 @@ func (db *oracle) GetTables() ([]*Table, error) {
return tables, nil
}
func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName}
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -205,7 +205,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
}
defer rows.Close()
indexes := make(map[string]*Index, 0)
indexes := make(map[string]*core.Index, 0)
for rows.Next() {
var indexType int
var indexName, colName, uniqueness string
@ -218,15 +218,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
indexName = strings.Trim(indexName, `" `)
if uniqueness == "UNIQUE" {
indexType = UniqueType
indexType = core.UniqueType
} else {
indexType = IndexType
indexType = core.IndexType
}
var index *Index
var index *core.Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(Index)
index = new(core.Index)
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
@ -240,7 +240,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
type OracleSeqFilter struct {
}
func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
func (s *OracleSeqFilter) Do(sql string, dialect core.Dialect, table *core.Table) string {
counts := strings.Count(sql, "?")
for i := 1; i <= counts; i++ {
newstr := ":" + fmt.Sprintf("%v", i)
@ -249,6 +249,6 @@ func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
return sql
}
func (db *oracle) Filters() []Filter {
return []Filter{&QuoteFilter{}, &OracleSeqFilter{}, &IdFilter{}}
func (db *oracle) Filters() []core.Filter {
return []core.Filter{&core.QuoteFilter{}, &OracleSeqFilter{}, &core.IdFilter{}}
}

View File

@ -1,4 +1,4 @@
package dialects
package xorm
import (
"errors"
@ -6,53 +6,53 @@ import (
"strconv"
"strings"
. "github.com/go-xorm/core"
"github.com/go-xorm/core"
)
func init() {
RegisterDialect("postgres", &postgres{})
}
// func init() {
// RegisterDialect("postgres", &postgres{})
// }
type postgres struct {
Base
core.Base
}
func (db *postgres) Init(uri *Uri, drivername, dataSourceName string) error {
func (db *postgres) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
}
func (db *postgres) SqlType(c *Column) string {
func (db *postgres) SqlType(c *core.Column) string {
var res string
switch t := c.SQLType.Name; t {
case TinyInt:
res = SmallInt
case core.TinyInt:
res = core.SmallInt
return res
case MediumInt, Int, Integer:
case core.MediumInt, core.Int, core.Integer:
if c.IsAutoIncrement {
return Serial
return core.Serial
}
return Integer
case Serial, BigSerial:
return core.Integer
case core.Serial, core.BigSerial:
c.IsAutoIncrement = true
c.Nullable = false
res = t
case Binary, VarBinary:
return Bytea
case DateTime:
res = TimeStamp
case TimeStampz:
case core.Binary, core.VarBinary:
return core.Bytea
case core.DateTime:
res = core.TimeStamp
case core.TimeStampz:
return "timestamp with time zone"
case Float:
res = Real
case TinyText, MediumText, LongText:
res = Text
case Blob, TinyBlob, MediumBlob, LongBlob:
return Bytea
case Double:
case core.Float:
res = core.Real
case core.TinyText, core.MediumText, core.LongText:
res = core.Text
case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob:
return core.Bytea
case core.Double:
return "DOUBLE PRECISION"
default:
if c.IsAutoIncrement {
return Serial
return core.Serial
}
res = t
}
@ -108,11 +108,11 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa
" AND column_name = ?", args
}
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, nil, err
}
@ -121,11 +121,11 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
if err != nil {
return nil, nil, err
}
cols := make(map[string]*Column)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)
for rows.Next() {
col := new(Column)
col := new(core.Column)
col.Indexes = make(map[string]bool)
var colName, isNullable, dataType string
@ -161,21 +161,21 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
switch dataType {
case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0}
col.SQLType = core.SQLType{core.Varchar, 0, 0}
case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0}
col.SQLType = core.SQLType{core.DateTime, 0, 0}
case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0}
col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
case "double precision":
col.SQLType = SQLType{Double, 0, 0}
col.SQLType = core.SQLType{core.Double, 0, 0}
case "boolean":
col.SQLType = SQLType{Bool, 0, 0}
col.SQLType = core.SQLType{core.Bool, 0, 0}
case "time without time zone":
col.SQLType = SQLType{Time, 0, 0}
col.SQLType = core.SQLType{core.Time, 0, 0}
default:
col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0}
col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
}
if _, ok := SqlTypes[col.SQLType.Name]; !ok {
if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
}
@ -197,10 +197,10 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
return colSeq, cols, nil
}
func (db *postgres) GetTables() ([]*Table, error) {
func (db *postgres) GetTables() ([]*core.Table, error) {
args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -210,9 +210,9 @@ func (db *postgres) GetTables() ([]*Table, error) {
return nil, err
}
tables := make([]*Table, 0)
tables := make([]*core.Table, 0)
for rows.Next() {
table := NewEmptyTable()
table := core.NewEmptyTable()
var name string
err = rows.Scan(&name)
if err != nil {
@ -224,11 +224,11 @@ func (db *postgres) GetTables() ([]*Table, error) {
return tables, nil
}
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName}
s := "SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -238,7 +238,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, err
}
indexes := make(map[string]*Index, 0)
indexes := make(map[string]*core.Index, 0)
for rows.Next() {
var indexType int
var indexName, indexdef string
@ -250,9 +250,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
indexName = strings.Trim(indexName, `" `)
if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") {
indexType = UniqueType
indexType = core.UniqueType
} else {
indexType = IndexType
indexType = core.IndexType
}
cs := strings.Split(indexdef, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
@ -267,7 +267,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
}
}
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
}
@ -280,7 +280,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
type PgSeqFilter struct {
}
func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
func (s *PgSeqFilter) Do(sql string, dialect core.Dialect, table *core.Table) string {
segs := strings.Split(sql, "?")
size := len(segs)
res := ""
@ -293,6 +293,6 @@ func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
return res
}
func (db *postgres) Filters() []Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}, &PgSeqFilter{}}
func (db *postgres) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &PgSeqFilter{}}
}

View File

@ -1,44 +1,44 @@
package dialects
package xorm
import (
"strings"
. "github.com/go-xorm/core"
"github.com/go-xorm/core"
)
func init() {
RegisterDialect("sqlite3", &sqlite3{})
}
// func init() {
// RegisterDialect("sqlite3", &sqlite3{})
// }
type sqlite3 struct {
Base
core.Base
}
func (db *sqlite3) Init(uri *Uri, drivername, dataSourceName string) error {
func (db *sqlite3) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName)
}
func (db *sqlite3) SqlType(c *Column) string {
func (db *sqlite3) SqlType(c *core.Column) string {
switch t := c.SQLType.Name; t {
case Date, DateTime, TimeStamp, Time:
return Numeric
case TimeStampz:
return Text
case Char, Varchar, TinyText, Text, MediumText, LongText:
return Text
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
return Integer
case Float, Double, Real:
return Real
case Decimal, Numeric:
return Numeric
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
return Blob
case Serial, BigSerial:
case core.Date, core.DateTime, core.TimeStamp, core.Time:
return core.Numeric
case core.TimeStampz:
return core.Text
case core.Char, core.Varchar, core.TinyText, core.Text, core.MediumText, core.LongText:
return core.Text
case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool:
return core.Integer
case core.Float, core.Double, core.Real:
return core.Real
case core.Decimal, core.Numeric:
return core.Numeric
case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary:
return core.Blob
case core.Serial, core.BigSerial:
c.IsPrimaryKey = true
c.IsAutoIncrement = true
c.Nullable = false
return Integer
return core.Integer
default:
return t
}
@ -84,10 +84,10 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac
return sql, args
}
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, nil, err
}
@ -110,11 +110,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e
nStart := strings.Index(name, "(")
nEnd := strings.Index(name, ")")
colCreates := strings.Split(name[nStart+1:nEnd], ",")
cols := make(map[string]*Column)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)
for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column)
col := new(core.Column)
col.Indexes = make(map[string]bool)
col.Nullable = true
for idx, field := range fields {
@ -122,7 +122,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e
col.Name = strings.Trim(field, "`[] ")
continue
} else if idx == 1 {
col.SQLType = SQLType{field, 0, 0}
col.SQLType = core.SQLType{field, 0, 0}
}
switch field {
case "PRIMARY":
@ -143,11 +143,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e
return colSeq, cols, nil
}
func (db *sqlite3) GetTables() ([]*Table, error) {
func (db *sqlite3) GetTables() ([]*core.Table, error) {
args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -158,9 +158,9 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
}
defer rows.Close()
tables := make([]*Table, 0)
tables := make([]*core.Table, 0)
for rows.Next() {
table := NewEmptyTable()
table := core.NewEmptyTable()
err = rows.Scan(&table.Name)
if err != nil {
return nil, err
@ -173,10 +173,10 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
return tables, nil
}
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName())
cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil {
return nil, err
}
@ -187,7 +187,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
}
defer rows.Close()
indexes := make(map[string]*Index, 0)
indexes := make(map[string]*core.Index, 0)
for rows.Next() {
var sql string
err = rows.Scan(&sql)
@ -199,7 +199,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
continue
}
index := new(Index)
index := new(core.Index)
nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON")
if nNStart == -1 || nNEnd == -1 {
@ -215,9 +215,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
}
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType
index.Type = core.UniqueType
} else {
index.Type = IndexType
index.Type = core.IndexType
}
nStart := strings.Index(sql, "(")
@ -234,6 +234,6 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
return indexes, nil
}
func (db *sqlite3) Filters() []Filter {
return []Filter{&IdFilter{}}
func (db *sqlite3) Filters() []core.Filter {
return []core.Filter{&core.IdFilter{}}
}

23
xorm.go
View File

@ -1,6 +1,7 @@
package xorm
import (
"database/sql"
"errors"
"fmt"
"os"
@ -11,7 +12,6 @@ import (
"github.com/go-xorm/core"
"github.com/go-xorm/xorm/caches"
_ "github.com/go-xorm/xorm/dialects"
_ "github.com/go-xorm/xorm/drivers"
)
@ -19,6 +19,27 @@ const (
Version string = "0.4"
)
func init() {
provided_dialects := map[string]struct {
dbType core.DbType
get func() core.Dialect
}{
"odbc": {"mssql", func() core.Dialect { return &mssql{} }},
"mysql": {"mysql", func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Dialect { return &mysql{} }},
"oci8": {"oracle", func() core.Dialect { return &oracle{} }},
"postgres": {"postgres", func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Dialect { return &sqlite3{} }},
}
for k, v := range provided_dialects {
_, err := sql.Open(string(k), "")
if err == nil {
core.RegisterDialect(v.dbType, v.get())
}
}
}
func close(engine *Engine) {
engine.Close()
}