flattened dialects dir and register db dialect for assocaited registered driver

This commit is contained in:
Nash Tsai 2014-04-11 21:06:11 +08:00
parent a0919b5371
commit 81c947b61b
6 changed files with 241 additions and 220 deletions

View File

@ -1,4 +1,4 @@
package dialects package xorm
import ( import (
"errors" "errors"
@ -6,58 +6,58 @@ import (
"strconv" "strconv"
"strings" "strings"
. "github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func init() { // func init() {
RegisterDialect("mssql", &mssql{}) // RegisterDialect("mssql", &mssql{})
} // }
type mssql struct { type mssql struct {
Base core.Base
} }
func (db *mssql) Init(uri *Uri, drivername, dataSourceName string) error { func (db *mssql) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName) return db.Base.Init(db, uri, drivername, dataSourceName)
} }
func (db *mssql) SqlType(c *Column) string { func (db *mssql) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bool: case core.Bool:
res = TinyInt res = core.TinyInt
case Serial: case core.Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = Int res = core.Int
case BigSerial: case core.BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = BigInt res = core.BigInt
case Bytea, Blob, Binary, TinyBlob, MediumBlob, LongBlob: case core.Bytea, core.Blob, core.Binary, core.TinyBlob, core.MediumBlob, core.LongBlob:
res = VarBinary res = core.VarBinary
if c.Length == 0 { if c.Length == 0 {
c.Length = 50 c.Length = 50
} }
case TimeStamp: case core.TimeStamp:
res = DateTime res = core.DateTime
case TimeStampz: case core.TimeStampz:
res = "DATETIMEOFFSET" res = "DATETIMEOFFSET"
c.Length = 7 c.Length = 7
case MediumInt: case core.MediumInt:
res = Int res = core.Int
case MediumText, TinyText, LongText: case core.MediumText, core.TinyText, core.LongText:
res = Text res = core.Text
case Double: case core.Double:
res = Real res = core.Real
default: default:
res = t res = t
} }
if res == Int { if res == core.Int {
return Int return core.Int
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
@ -118,12 +118,12 @@ func (db *mssql) TableCheckSql(tableName string) (string, []interface{}) {
return sql, args return sql, args
} }
func (db *mssql) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{} args := []interface{}{}
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale
from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id
where a.object_id=object_id('` + tableName + `')` where a.object_id=object_id('` + tableName + `')`
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -133,7 +133,7 @@ where a.object_id=object_id('` + tableName + `')`
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*core.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
var name, ctype, precision, scale string var name, ctype, precision, scale string
@ -143,7 +143,7 @@ where a.object_id=object_id('` + tableName + `')`
return nil, nil, err return nil, nil, err
} }
col := new(Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Length = maxLen col.Length = maxLen
col.Name = strings.Trim(name, "` ") col.Name = strings.Trim(name, "` ")
@ -151,14 +151,14 @@ where a.object_id=object_id('` + tableName + `')`
ct := strings.ToUpper(ctype) ct := strings.ToUpper(ctype)
switch ct { switch ct {
case "DATETIMEOFFSET": case "DATETIMEOFFSET":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
case "NVARCHAR": case "NVARCHAR":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = core.SQLType{core.Varchar, 0, 0}
case "IMAGE": case "IMAGE":
col.SQLType = SQLType{VarBinary, 0, 0} col.SQLType = core.SQLType{core.VarBinary, 0, 0}
default: default:
if _, ok := SqlTypes[ct]; ok { if _, ok := core.SqlTypes[ct]; ok {
col.SQLType = SQLType{ct, 0, 0} col.SQLType = core.SQLType{ct, 0, 0}
} else { } else {
return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v", return nil, nil, errors.New(fmt.Sprintf("unknow colType %v for %v - %v",
ct, tableName, col.Name)) ct, tableName, col.Name))
@ -180,10 +180,10 @@ where a.object_id=object_id('` + tableName + `')`
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mssql) GetTables() ([]*Table, error) { func (db *mssql) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{}
s := `select name from sysobjects where xtype ='U'` s := `select name from sysobjects where xtype ='U'`
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -193,9 +193,9 @@ func (db *mssql) GetTables() ([]*Table, error) {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*core.Table, 0)
for rows.Next() { for rows.Next() {
table := NewEmptyTable() table := core.NewEmptyTable()
var name string var name string
err = rows.Scan(&name) err = rows.Scan(&name)
if err != nil { if err != nil {
@ -207,7 +207,7 @@ func (db *mssql) GetTables() ([]*Table, error) {
return tables, nil return tables, nil
} }
func (db *mssql) GetIndexes(tableName string) (map[string]*Index, error) { func (db *mssql) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := `SELECT s := `SELECT
IXS.NAME AS [INDEX_NAME], IXS.NAME AS [INDEX_NAME],
@ -223,7 +223,7 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID
AND IXCS.COLUMN_ID=C.COLUMN_ID AND IXCS.COLUMN_ID=C.COLUMN_ID
WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
` `
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -233,7 +233,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, colName, isUnique string var indexName, colName, isUnique string
@ -249,9 +249,9 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
} }
if i { if i {
indexType = UniqueType indexType = core.UniqueType
} else { } else {
indexType = IndexType indexType = core.IndexType
} }
colName = strings.Trim(colName, "` ") colName = strings.Trim(colName, "` ")
@ -260,10 +260,10 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
var index *Index var index *core.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(core.Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
@ -273,7 +273,7 @@ WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =?
return indexes, nil return indexes, nil
} }
func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset string) string { func (db *mssql) CreateTablSql(table *core.Table, tableName, storeEngine, charset string) string {
var sql string var sql string
if tableName == "" { if tableName == "" {
tableName = table.Name tableName = table.Name
@ -307,6 +307,6 @@ func (db *mssql) CreateTablSql(table *Table, tableName, storeEngine, charset str
return sql return sql
} }
func (db *mssql) Filters() []Filter { func (db *mssql) Filters() []core.Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}} return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}}
} }

View File

@ -1,4 +1,4 @@
package dialects package xorm
import ( import (
"crypto/tls" "crypto/tls"
@ -8,15 +8,15 @@ import (
"strings" "strings"
"time" "time"
. "github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func init() { // func init() {
RegisterDialect("mysql", &mysql{}) // RegisterDialect("mysql", &mysql{})
} // }
type mysql struct { type mysql struct {
Base core.Base
net string net string
addr string addr string
params map[string]string params map[string]string
@ -28,30 +28,30 @@ type mysql struct {
clientFoundRows bool clientFoundRows bool
} }
func (db *mysql) Init(uri *Uri, drivername, dataSourceName string) error { func (db *mysql) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName) return db.Base.Init(db, uri, drivername, dataSourceName)
} }
func (db *mysql) SqlType(c *Column) string { func (db *mysql) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bool: case core.Bool:
res = TinyInt res = core.TinyInt
c.Length = 1 c.Length = 1
case Serial: case core.Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = Int res = core.Int
case BigSerial: case core.BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = BigInt res = core.BigInt
case Bytea: case core.Bytea:
res = Blob res = core.Blob
case TimeStampz: case core.TimeStampz:
res = Char res = core.Char
c.Length = 64 c.Length = 64
default: default:
res = t res = t
@ -110,11 +110,11 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{db.DbName, tableName} args := []interface{}{db.DbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -123,10 +123,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*core.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
col := new(Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
var columnName, isNullable, colType, colKey, extra string var columnName, isNullable, colType, colKey, extra string
@ -164,8 +164,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
colType = strings.ToUpper(colName) colType = strings.ToUpper(colName)
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := SqlTypes[colType]; ok { if _, ok := core.SqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2} col.SQLType = core.SQLType{colType, len1, len2}
} else { } else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
} }
@ -192,10 +192,10 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, err
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*Table, error) { func (db *mysql) GetTables() ([]*core.Table, error) {
args := []interface{}{db.DbName} args := []interface{}{db.DbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -205,9 +205,9 @@ func (db *mysql) GetTables() ([]*Table, error) {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*core.Table, 0)
for rows.Next() { for rows.Next() {
table := NewEmptyTable() table := core.NewEmptyTable()
var name, engine, tableRows string var name, engine, tableRows string
var autoIncr *string var autoIncr *string
err = rows.Scan(&name, &engine, &tableRows, &autoIncr) err = rows.Scan(&name, &engine, &tableRows, &autoIncr)
@ -221,10 +221,10 @@ func (db *mysql) GetTables() ([]*Table, error) {
return tables, nil return tables, nil
} }
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{db.DbName, tableName} args := []interface{}{db.DbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -234,7 +234,7 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, colName, nonUnique string var indexName, colName, nonUnique string
@ -248,9 +248,9 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
} }
if "YES" == nonUnique || nonUnique == "1" { if "YES" == nonUnique || nonUnique == "1" {
indexType = IndexType indexType = core.IndexType
} else { } else {
indexType = UniqueType indexType = core.UniqueType
} }
colName = strings.Trim(colName, "` ") colName = strings.Trim(colName, "` ")
@ -259,10 +259,10 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
var index *Index var index *core.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(core.Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
@ -272,6 +272,6 @@ func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
return indexes, nil return indexes, nil
} }
func (db *mysql) Filters() []Filter { func (db *mysql) Filters() []core.Filter {
return []Filter{&IdFilter{}} return []core.Filter{&core.IdFilter{}}
} }

View File

@ -1,4 +1,4 @@
package dialects package xorm
import ( import (
"errors" "errors"
@ -6,37 +6,37 @@ import (
"strconv" "strconv"
"strings" "strings"
. "github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func init() { // func init() {
RegisterDialect("oracle", &oracle{}) // RegisterDialect("oracle", &oracle{})
} // }
type oracle struct { type oracle struct {
Base core.Base
} }
func (db *oracle) Init(uri *Uri, drivername, dataSourceName string) error { func (db *oracle) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName) return db.Base.Init(db, uri, drivername, dataSourceName)
} }
func (db *oracle) SqlType(c *Column) string { func (db *oracle) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool, Serial, BigSerial: case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool, core.Serial, core.BigSerial:
return "NUMBER" return "NUMBER"
case Binary, VarBinary, Blob, TinyBlob, MediumBlob, LongBlob, Bytea: case core.Binary, core.VarBinary, core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob, core.Bytea:
return Blob return core.Blob
case Time, DateTime, TimeStamp: case core.Time, core.DateTime, core.TimeStamp:
res = TimeStamp res = core.TimeStamp
case TimeStampz: case core.TimeStampz:
res = "TIMESTAMP WITH TIME ZONE" res = "TIMESTAMP WITH TIME ZONE"
case Float, Double, Numeric, Decimal: case core.Float, core.Double, core.Numeric, core.Decimal:
res = "NUMBER" res = "NUMBER"
case Text, MediumText, LongText: case core.Text, core.MediumText, core.LongText:
res = "CLOB" res = "CLOB"
case Char, Varchar, TinyText: case core.Char, core.Varchar, core.TinyText:
return "VARCHAR2" return "VARCHAR2"
default: default:
res = t res = t
@ -93,12 +93,12 @@ func (db *oracle) ColumnCheckSql(tableName, colName string) (string, []interface
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{strings.ToUpper(tableName)} args := []interface{}{strings.ToUpper(tableName)}
s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," +
"nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -109,10 +109,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
} }
defer rows.Close() defer rows.Close()
cols := make(map[string]*Column) cols := make(map[string]*core.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
col := new(Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
var colName, colDefault, nullable, dataType, dataPrecision, dataScale string var colName, colDefault, nullable, dataType, dataPrecision, dataScale string
@ -135,13 +135,13 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
switch dataType { switch dataType {
case "VARCHAR2": case "VARCHAR2":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = core.SQLType{core.Varchar, 0, 0}
case "TIMESTAMP WITH TIME ZONE": case "TIMESTAMP WITH TIME ZONE":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
default: default:
col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0} col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
} }
if _, ok := SqlTypes[col.SQLType.Name]; !ok { if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
} }
@ -163,10 +163,10 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*Column, er
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *oracle) GetTables() ([]*Table, error) { func (db *oracle) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT table_name FROM user_tables" s := "SELECT table_name FROM user_tables"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -176,9 +176,9 @@ func (db *oracle) GetTables() ([]*Table, error) {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*core.Table, 0)
for rows.Next() { for rows.Next() {
table := NewEmptyTable() table := core.NewEmptyTable()
err = rows.Scan(&table.Name) err = rows.Scan(&table.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -189,12 +189,12 @@ func (db *oracle) GetTables() ([]*Table, error) {
return tables, nil return tables, nil
} }
func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) { func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " +
"WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -205,7 +205,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, colName, uniqueness string var indexName, colName, uniqueness string
@ -218,15 +218,15 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
indexName = strings.Trim(indexName, `" `) indexName = strings.Trim(indexName, `" `)
if uniqueness == "UNIQUE" { if uniqueness == "UNIQUE" {
indexType = UniqueType indexType = core.UniqueType
} else { } else {
indexType = IndexType indexType = core.IndexType
} }
var index *Index var index *core.Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(core.Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
@ -240,7 +240,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*Index, error) {
type OracleSeqFilter struct { type OracleSeqFilter struct {
} }
func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string { func (s *OracleSeqFilter) Do(sql string, dialect core.Dialect, table *core.Table) string {
counts := strings.Count(sql, "?") counts := strings.Count(sql, "?")
for i := 1; i <= counts; i++ { for i := 1; i <= counts; i++ {
newstr := ":" + fmt.Sprintf("%v", i) newstr := ":" + fmt.Sprintf("%v", i)
@ -249,6 +249,6 @@ func (s *OracleSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
return sql return sql
} }
func (db *oracle) Filters() []Filter { func (db *oracle) Filters() []core.Filter {
return []Filter{&QuoteFilter{}, &OracleSeqFilter{}, &IdFilter{}} return []core.Filter{&core.QuoteFilter{}, &OracleSeqFilter{}, &core.IdFilter{}}
} }

View File

@ -1,4 +1,4 @@
package dialects package xorm
import ( import (
"errors" "errors"
@ -6,53 +6,53 @@ import (
"strconv" "strconv"
"strings" "strings"
. "github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func init() { // func init() {
RegisterDialect("postgres", &postgres{}) // RegisterDialect("postgres", &postgres{})
} // }
type postgres struct { type postgres struct {
Base core.Base
} }
func (db *postgres) Init(uri *Uri, drivername, dataSourceName string) error { func (db *postgres) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName) return db.Base.Init(db, uri, drivername, dataSourceName)
} }
func (db *postgres) SqlType(c *Column) string { func (db *postgres) SqlType(c *core.Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case TinyInt: case core.TinyInt:
res = SmallInt res = core.SmallInt
return res return res
case MediumInt, Int, Integer: case core.MediumInt, core.Int, core.Integer:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return core.Serial
} }
return Integer return core.Integer
case Serial, BigSerial: case core.Serial, core.BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
res = t res = t
case Binary, VarBinary: case core.Binary, core.VarBinary:
return Bytea return core.Bytea
case DateTime: case core.DateTime:
res = TimeStamp res = core.TimeStamp
case TimeStampz: case core.TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case Float: case core.Float:
res = Real res = core.Real
case TinyText, MediumText, LongText: case core.TinyText, core.MediumText, core.LongText:
res = Text res = core.Text
case Blob, TinyBlob, MediumBlob, LongBlob: case core.Blob, core.TinyBlob, core.MediumBlob, core.LongBlob:
return Bytea return core.Bytea
case Double: case core.Double:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return core.Serial
} }
res = t res = t
} }
@ -108,11 +108,11 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -121,11 +121,11 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*core.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for rows.Next() { for rows.Next() {
col := new(Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
var colName, isNullable, dataType string var colName, isNullable, dataType string
@ -161,21 +161,21 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
switch dataType { switch dataType {
case "character varying", "character": case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = core.SQLType{core.Varchar, 0, 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0} col.SQLType = core.SQLType{core.DateTime, 0, 0}
case "timestamp with time zone": case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = core.SQLType{core.TimeStampz, 0, 0}
case "double precision": case "double precision":
col.SQLType = SQLType{Double, 0, 0} col.SQLType = core.SQLType{core.Double, 0, 0}
case "boolean": case "boolean":
col.SQLType = SQLType{Bool, 0, 0} col.SQLType = core.SQLType{core.Bool, 0, 0}
case "time without time zone": case "time without time zone":
col.SQLType = SQLType{Time, 0, 0} col.SQLType = core.SQLType{core.Time, 0, 0}
default: default:
col.SQLType = SQLType{strings.ToUpper(dataType), 0, 0} col.SQLType = core.SQLType{strings.ToUpper(dataType), 0, 0}
} }
if _, ok := SqlTypes[col.SQLType.Name]; !ok { if _, ok := core.SqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", dataType))
} }
@ -197,10 +197,10 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column,
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'" s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -210,9 +210,9 @@ func (db *postgres) GetTables() ([]*Table, error) {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*core.Table, 0)
for rows.Next() { for rows.Next() {
table := NewEmptyTable() table := core.NewEmptyTable()
var name string var name string
err = rows.Scan(&name) err = rows.Scan(&name)
if err != nil { if err != nil {
@ -224,11 +224,11 @@ func (db *postgres) GetTables() ([]*Table, error) {
return tables, nil return tables, nil
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" s := "SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -238,7 +238,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var indexType int var indexType int
var indexName, indexdef string var indexName, indexdef string
@ -250,9 +250,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
indexName = strings.Trim(indexName, `" `) indexName = strings.Trim(indexName, `" `)
if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") {
indexType = UniqueType indexType = core.UniqueType
} else { } else {
indexType = IndexType indexType = core.IndexType
} }
cs := strings.Split(indexdef, "(") cs := strings.Split(indexdef, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
@ -267,7 +267,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
} }
} }
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &core.Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
@ -280,7 +280,7 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
type PgSeqFilter struct { type PgSeqFilter struct {
} }
func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string { func (s *PgSeqFilter) Do(sql string, dialect core.Dialect, table *core.Table) string {
segs := strings.Split(sql, "?") segs := strings.Split(sql, "?")
size := len(segs) size := len(segs)
res := "" res := ""
@ -293,6 +293,6 @@ func (s *PgSeqFilter) Do(sql string, dialect Dialect, table *Table) string {
return res return res
} }
func (db *postgres) Filters() []Filter { func (db *postgres) Filters() []core.Filter {
return []Filter{&IdFilter{}, &QuoteFilter{}, &PgSeqFilter{}} return []core.Filter{&core.IdFilter{}, &core.QuoteFilter{}, &PgSeqFilter{}}
} }

View File

@ -1,44 +1,44 @@
package dialects package xorm
import ( import (
"strings" "strings"
. "github.com/go-xorm/core" "github.com/go-xorm/core"
) )
func init() { // func init() {
RegisterDialect("sqlite3", &sqlite3{}) // RegisterDialect("sqlite3", &sqlite3{})
} // }
type sqlite3 struct { type sqlite3 struct {
Base core.Base
} }
func (db *sqlite3) Init(uri *Uri, drivername, dataSourceName string) error { func (db *sqlite3) Init(uri *core.Uri, drivername, dataSourceName string) error {
return db.Base.Init(db, uri, drivername, dataSourceName) return db.Base.Init(db, uri, drivername, dataSourceName)
} }
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *core.Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Date, DateTime, TimeStamp, Time: case core.Date, core.DateTime, core.TimeStamp, core.Time:
return Numeric return core.Numeric
case TimeStampz: case core.TimeStampz:
return Text return core.Text
case Char, Varchar, TinyText, Text, MediumText, LongText: case core.Char, core.Varchar, core.TinyText, core.Text, core.MediumText, core.LongText:
return Text return core.Text
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool:
return Integer return core.Integer
case Float, Double, Real: case core.Float, core.Double, core.Real:
return Real return core.Real
case Decimal, Numeric: case core.Decimal, core.Numeric:
return Numeric return core.Numeric
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: case core.TinyBlob, core.Blob, core.MediumBlob, core.LongBlob, core.Bytea, core.Binary, core.VarBinary:
return Blob return core.Blob
case Serial, BigSerial: case core.Serial, core.BigSerial:
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
return Integer return core.Integer
default: default:
return t return t
} }
@ -84,10 +84,10 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac
return sql, args return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -110,11 +110,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e
nStart := strings.Index(name, "(") nStart := strings.Index(name, "(")
nEnd := strings.Index(name, ")") nEnd := strings.Index(name, ")")
colCreates := strings.Split(name[nStart+1:nEnd], ",") colCreates := strings.Split(name[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*core.Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(core.Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true col.Nullable = true
for idx, field := range fields { for idx, field := range fields {
@ -122,7 +122,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
continue continue
} else if idx == 1 { } else if idx == 1 {
col.SQLType = SQLType{field, 0, 0} col.SQLType = core.SQLType{field, 0, 0}
} }
switch field { switch field {
case "PRIMARY": case "PRIMARY":
@ -143,11 +143,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, e
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*core.Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -158,9 +158,9 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
} }
defer rows.Close() defer rows.Close()
tables := make([]*Table, 0) tables := make([]*core.Table, 0)
for rows.Next() { for rows.Next() {
table := NewEmptyTable() table := core.NewEmptyTable()
err = rows.Scan(&table.Name) err = rows.Scan(&table.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -173,10 +173,10 @@ func (db *sqlite3) GetTables() ([]*Table, error) {
return tables, nil return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := Open(db.DriverName(), db.DataSourceName()) cnn, err := core.Open(db.DriverName(), db.DataSourceName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -187,7 +187,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
} }
defer rows.Close() defer rows.Close()
indexes := make(map[string]*Index, 0) indexes := make(map[string]*core.Index, 0)
for rows.Next() { for rows.Next() {
var sql string var sql string
err = rows.Scan(&sql) err = rows.Scan(&sql)
@ -199,7 +199,7 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
continue continue
} }
index := new(Index) index := new(core.Index)
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
if nNStart == -1 || nNEnd == -1 { if nNStart == -1 || nNEnd == -1 {
@ -215,9 +215,9 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
} }
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType index.Type = core.UniqueType
} else { } else {
index.Type = IndexType index.Type = core.IndexType
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
@ -234,6 +234,6 @@ func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
return indexes, nil return indexes, nil
} }
func (db *sqlite3) Filters() []Filter { func (db *sqlite3) Filters() []core.Filter {
return []Filter{&IdFilter{}} return []core.Filter{&core.IdFilter{}}
} }

23
xorm.go
View File

@ -1,6 +1,7 @@
package xorm package xorm
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@ -11,7 +12,6 @@ import (
"github.com/go-xorm/core" "github.com/go-xorm/core"
"github.com/go-xorm/xorm/caches" "github.com/go-xorm/xorm/caches"
_ "github.com/go-xorm/xorm/dialects"
_ "github.com/go-xorm/xorm/drivers" _ "github.com/go-xorm/xorm/drivers"
) )
@ -19,6 +19,27 @@ const (
Version string = "0.4" Version string = "0.4"
) )
func init() {
provided_dialects := map[string]struct {
dbType core.DbType
get func() core.Dialect
}{
"odbc": {"mssql", func() core.Dialect { return &mssql{} }},
"mysql": {"mysql", func() core.Dialect { return &mysql{} }},
"mymysql": {"mysql", func() core.Dialect { return &mysql{} }},
"oci8": {"oracle", func() core.Dialect { return &oracle{} }},
"postgres": {"postgres", func() core.Dialect { return &postgres{} }},
"sqlite3": {"sqlite3", func() core.Dialect { return &sqlite3{} }},
}
for k, v := range provided_dialects {
_, err := sql.Open(string(k), "")
if err == nil {
core.RegisterDialect(v.dbType, v.get())
}
}
}
func close(engine *Engine) { func close(engine *Engine) {
engine.Close() engine.Close()
} }