refactoring db interface

This commit is contained in:
Lunny Xiao 2013-12-17 17:30:05 +08:00
parent 99c7031b50
commit 4c7e98d8d6
4 changed files with 725 additions and 701 deletions

View File

@ -1,66 +1,67 @@
package xorm package xorm
import ( import (
"errors" "errors"
"strings" "strings"
"time" "time"
) )
type mymysql struct { type mymysql struct {
mysql mysql
proto string }
raddr string
laddr string type mymysqlParser struct {
timeout time.Duration }
db string
user string func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
passwd string db := &uri{dbType: MYSQL}
pd := strings.SplitN(dataSourceName, "*", 2)
if len(pd) == 2 {
// Parse protocol part of URI
p := strings.SplitN(pd[0], ":", 2)
if len(p) != 2 {
return nil, errors.New("Wrong protocol part of URI")
}
db.proto = p[0]
options := strings.Split(p[1], ",")
db.raddr = options[0]
for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2)
var k, v string
if len(kv) == 2 {
k, v = kv[0], kv[1]
} else {
k, v = o, "true"
}
switch k {
case "laddr":
db.laddr = v
case "timeout":
to, err := time.ParseDuration(v)
if err != nil {
return nil, err
}
db.timeout = to
default:
return nil, errors.New("Unknown option: " + k)
}
}
// Remove protocol part
pd = pd[1:]
}
// Parse database part of URI
dup := strings.SplitN(pd[0], "/", 3)
if len(dup) != 3 {
return nil, errors.New("Wrong database part of URI")
}
db.dbName = dup[0]
db.user = dup[1]
db.passwd = dup[2]
return db, nil
} }
func (db *mymysql) Init(drivername, uri string) error { func (db *mymysql) Init(drivername, uri string) error {
db.mysql.base.init(drivername, uri) return db.mysql.base.init(&mymysqlParser{}, drivername, uri)
pd := strings.SplitN(uri, "*", 2)
if len(pd) == 2 {
// Parse protocol part of URI
p := strings.SplitN(pd[0], ":", 2)
if len(p) != 2 {
return errors.New("Wrong protocol part of URI")
}
db.proto = p[0]
options := strings.Split(p[1], ",")
db.raddr = options[0]
for _, o := range options[1:] {
kv := strings.SplitN(o, "=", 2)
var k, v string
if len(kv) == 2 {
k, v = kv[0], kv[1]
} else {
k, v = o, "true"
}
switch k {
case "laddr":
db.laddr = v
case "timeout":
to, err := time.ParseDuration(v)
if err != nil {
return err
}
db.timeout = to
default:
return errors.New("Unknown option: " + k)
}
}
// Remove protocol part
pd = pd[1:]
}
// Parse database part of URI
dup := strings.SplitN(pd[0], "/", 3)
if len(dup) != 3 {
return errors.New("Wrong database part of URI")
}
db.dbname = dup[0]
db.user = dup[1]
db.passwd = dup[2]
return nil
} }

520
mysql.go
View File

@ -1,311 +1,323 @@
package xorm package xorm
import ( import (
"crypto/tls" "crypto/tls"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
type base struct { type uri struct {
drivername string dbType string
dataSourceName string proto string
host string
port string
dbName string
user string
passwd string
charset string
laddr string
raddr string
timeout time.Duration
} }
func (b *base) init(drivername, dataSourceName string) { type parser interface {
b.drivername, b.dataSourceName = drivername, dataSourceName parse(driverName, dataSourceName string) (*uri, error)
}
type mysqlParser struct {
}
func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
//cfg.params = make(map[string]string)
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dataSourceName)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
uri := &uri{dbType: MYSQL}
for i, match := range matches {
switch names[i] {
case "dbname":
uri.dbName = match
}
}
return uri, nil
}
type base struct {
parser parser
driverName string
dataSourceName string
*uri
}
func (b *base) init(parser parser, drivername, dataSourceName string) (err error) {
b.parser = parser
b.driverName, b.dataSourceName = drivername, dataSourceName
b.uri, err = b.parser.parse(b.driverName, b.dataSourceName)
return
} }
type mysql struct { type mysql struct {
base base
user string net string
passwd string addr string
net string params map[string]string
addr string loc *time.Location
dbname string timeout time.Duration
params map[string]string tls *tls.Config
loc *time.Location allowAllFiles bool
timeout time.Duration allowOldPasswords bool
tls *tls.Config clientFoundRows bool
allowAllFiles bool
allowOldPasswords bool
clientFoundRows bool
}
/*func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}*/
func (cfg *mysql) parseDSN(dsn string) (err error) {
//cfg.params = make(map[string]string)
dsnPattern := regexp.MustCompile(
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
`\/(?P<dbname>.*?)` + // /dbname
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1&paramN=valueN]
matches := dsnPattern.FindStringSubmatch(dsn)
//tlsConfigRegister := make(map[string]*tls.Config)
names := dsnPattern.SubexpNames()
for i, match := range matches {
switch names[i] {
case "dbname":
cfg.dbname = match
}
}
return
} }
func (db *mysql) Init(drivername, uri string) error { func (db *mysql) Init(drivername, uri string) error {
db.base.init(drivername, uri) return db.base.init(&mysqlParser{}, drivername, uri)
return db.parseDSN(uri)
} }
func (db *mysql) SqlType(c *Column) string { func (db *mysql) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Bool: case Bool:
res = TinyInt res = TinyInt
case Serial: case Serial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = Int res = Int
case BigSerial: case BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.Nullable = false c.Nullable = false
res = BigInt res = BigInt
case Bytea: case Bytea:
res = Blob res = Blob
case TimeStampz: case TimeStampz:
res = Char res = Char
c.Length = 64 c.Length = 64
default: default:
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *mysql) SupportInsertMany() bool { func (db *mysql) SupportInsertMany() bool {
return true return true
} }
func (db *mysql) QuoteStr() string { func (db *mysql) QuoteStr() string {
return "`" return "`"
} }
func (db *mysql) SupportEngine() bool { func (db *mysql) SupportEngine() bool {
return true return true
} }
func (db *mysql) AutoIncrStr() string { func (db *mysql) AutoIncrStr() string {
return "AUTO_INCREMENT" return "AUTO_INCREMENT"
} }
func (db *mysql) SupportCharset() bool { func (db *mysql) SupportCharset() bool {
return true return true
} }
func (db *mysql) IndexOnTable() bool { func (db *mysql) IndexOnTable() bool {
return true return true
} }
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName, idxName} args := []interface{}{db.dbName, tableName, idxName}
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName, colName} args := []interface{}{db.dbName, tableName, colName}
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
return sql, args return sql, args
} }
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbName, tableName}
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args return sql, args
} }
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbName, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "COLUMN_NAME": case "COLUMN_NAME":
col.Name = strings.Trim(string(content), "` ") col.Name = strings.Trim(string(content), "` ")
case "IS_NULLABLE": case "IS_NULLABLE":
if "YES" == string(content) { if "YES" == string(content) {
col.Nullable = true col.Nullable = true
} }
case "COLUMN_DEFAULT": case "COLUMN_DEFAULT":
// add '' // add ''
col.Default = string(content) col.Default = string(content)
case "COLUMN_TYPE": case "COLUMN_TYPE":
cts := strings.Split(string(content), "(") cts := strings.Split(string(content), "(")
var len1, len2 int var len1, len2 int
if len(cts) == 2 { if len(cts) == 2 {
idx := strings.Index(cts[1], ")") idx := strings.Index(cts[1], ")")
lens := strings.Split(cts[1][0:idx], ",") lens := strings.Split(cts[1][0:idx], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if len(lens) == 2 { if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1]) len2, err = strconv.Atoi(lens[1])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
} }
colName := cts[0] colName := cts[0]
colType := strings.ToUpper(colName) colType := strings.ToUpper(colName)
col.Length = len1 col.Length = len1
col.Length2 = len2 col.Length2 = len2
if _, ok := sqlTypes[colType]; ok { if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2} col.SQLType = SQLType{colType, len1, len2}
} else { } else {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
} }
case "COLUMN_KEY": case "COLUMN_KEY":
key := string(content) key := string(content)
if key == "PRI" { if key == "PRI" {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} }
if key == "UNI" { if key == "UNI" {
//col.is //col.is
} }
case "EXTRA": case "EXTRA":
extra := string(content) extra := string(content)
if extra == "auto_increment" { if extra == "auto_increment" {
col.IsAutoIncrement = true col.IsAutoIncrement = true
} }
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *mysql) GetTables() ([]*Table, error) { func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.dbname} args := []interface{}{db.dbName}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "TABLE_NAME": case "TABLE_NAME":
table.Name = strings.Trim(string(content), "` ") table.Name = strings.Trim(string(content), "` ")
case "ENGINE": case "ENGINE":
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.dbname, tableName} args := []interface{}{db.dbName, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName, colName string var indexName, colName string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "NON_UNIQUE": case "NON_UNIQUE":
if "YES" == string(content) || string(content) == "1" { if "YES" == string(content) || string(content) == "1" {
indexType = IndexType indexType = IndexType
} else { } else {
indexType = UniqueType indexType = UniqueType
} }
case "INDEX_NAME": case "INDEX_NAME":
indexName = string(content) indexName = string(content)
case "COLUMN_NAME": case "COLUMN_NAME":
colName = strings.Trim(string(content), "` ") colName = strings.Trim(string(content), "` ")
} }
} }
if indexName == "PRIMARY" { if indexName == "PRIMARY" {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
indexName = indexName[5+len(tableName) : len(indexName)] indexName = indexName[5+len(tableName) : len(indexName)]
} }
var index *Index var index *Index
var ok bool var ok bool
if index, ok = indexes[indexName]; !ok { if index, ok = indexes[indexName]; !ok {
index = new(Index) index = new(Index)
index.Type = indexType index.Type = indexType
index.Name = indexName index.Name = indexName
indexes[indexName] = index indexes[indexName] = index
} }
index.AddColumn(colName) index.AddColumn(colName)
} }
return indexes, nil return indexes, nil
} }

View File

@ -1,300 +1,305 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
) )
type postgres struct { type postgres struct {
base base
dbname string
} }
type values map[string]string type values map[string]string
func (vs values) Set(k, v string) { func (vs values) Set(k, v string) {
vs[k] = v vs[k] = v
} }
func (vs values) Get(k string) (v string) { func (vs values) Get(k string) (v string) {
return vs[k] return vs[k]
} }
func errorf(s string, args ...interface{}) { func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
} }
func parseOpts(name string, o values) { func parseOpts(name string, o values) {
if len(name) == 0 { if len(name) == 0 {
return return
} }
name = strings.TrimSpace(name) name = strings.TrimSpace(name)
ps := strings.Split(name, " ") ps := strings.Split(name, " ")
for _, p := range ps { for _, p := range ps {
kv := strings.Split(p, "=") kv := strings.Split(p, "=")
if len(kv) < 2 { if len(kv) < 2 {
errorf("invalid option: %q", p) errorf("invalid option: %q", p)
} }
o.Set(kv[0], kv[1]) o.Set(kv[0], kv[1])
} }
}
type postgresParser struct {
}
func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) {
db := &uri{dbType: POSTGRES}
o := make(values)
parseOpts(dataSourceName, o)
db.dbName = o.Get("dbname")
if db.dbName == "" {
return nil, errors.New("dbname is empty")
}
return db, nil
} }
func (db *postgres) Init(drivername, uri string) error { func (db *postgres) Init(drivername, uri string) error {
db.base.init(drivername, uri) return db.base.init(&postgresParser{}, drivername, uri)
o := make(values)
parseOpts(uri, o)
db.dbname = o.Get("dbname")
if db.dbname == "" {
return errors.New("dbname is empty")
}
return nil
} }
func (db *postgres) SqlType(c *Column) string { func (db *postgres) SqlType(c *Column) string {
var res string var res string
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case TinyInt: case TinyInt:
res = SmallInt res = SmallInt
case MediumInt, Int, Integer: case MediumInt, Int, Integer:
return Integer return Integer
case Serial, BigSerial: case Serial, BigSerial:
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
res = t res = t
case Binary, VarBinary: case Binary, VarBinary:
return Bytea return Bytea
case DateTime: case DateTime:
res = TimeStamp res = TimeStamp
case TimeStampz: case TimeStampz:
return "timestamp with time zone" return "timestamp with time zone"
case Float: case Float:
res = Real res = Real
case TinyText, MediumText, LongText: case TinyText, MediumText, LongText:
res = Text res = Text
case Blob, TinyBlob, MediumBlob, LongBlob: case Blob, TinyBlob, MediumBlob, LongBlob:
return Bytea return Bytea
case Double: case Double:
return "DOUBLE PRECISION" return "DOUBLE PRECISION"
default: default:
if c.IsAutoIncrement { if c.IsAutoIncrement {
return Serial return Serial
} }
res = t res = t
} }
var hasLen1 bool = (c.Length > 0) var hasLen1 bool = (c.Length > 0)
var hasLen2 bool = (c.Length2 > 0) var hasLen2 bool = (c.Length2 > 0)
if hasLen1 { if hasLen1 {
res += "(" + strconv.Itoa(c.Length) + ")" res += "(" + strconv.Itoa(c.Length) + ")"
} else if hasLen2 { } else if hasLen2 {
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
} }
return res return res
} }
func (db *postgres) SupportInsertMany() bool { func (db *postgres) SupportInsertMany() bool {
return true return true
} }
func (db *postgres) QuoteStr() string { func (db *postgres) QuoteStr() string {
return "\"" return "\""
} }
func (db *postgres) AutoIncrStr() string { func (db *postgres) AutoIncrStr() string {
return "" return ""
} }
func (db *postgres) SupportEngine() bool { func (db *postgres) SupportEngine() bool {
return false return false
} }
func (db *postgres) SupportCharset() bool { func (db *postgres) SupportCharset() bool {
return false return false
} }
func (db *postgres) IndexOnTable() bool { func (db *postgres) IndexOnTable() bool {
return false return false
} }
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{tableName, idxName} args := []interface{}{tableName, idxName}
return `SELECT indexname FROM pg_indexes ` + return `SELECT indexname FROM pg_indexes ` +
`WHERE tablename = ? AND indexname = ?`, args `WHERE tablename = ? AND indexname = ?`, args
} }
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
} }
func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName, colName} args := []interface{}{tableName, colName}
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args " AND column_name = ?", args
} }
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, record := range res { for _, record := range res {
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "column_name": case "column_name":
col.Name = strings.Trim(string(content), `" `) col.Name = strings.Trim(string(content), `" `)
case "column_default": case "column_default":
if strings.HasPrefix(string(content), "nextval") { if strings.HasPrefix(string(content), "nextval") {
col.IsPrimaryKey = true col.IsPrimaryKey = true
} else { } else {
col.Default = string(content) col.Default = string(content)
} }
case "is_nullable": case "is_nullable":
if string(content) == "YES" { if string(content) == "YES" {
col.Nullable = true col.Nullable = true
} else { } else {
col.Nullable = false col.Nullable = false
} }
case "data_type": case "data_type":
ct := string(content) ct := string(content)
switch ct { switch ct {
case "character varying", "character": case "character varying", "character":
col.SQLType = SQLType{Varchar, 0, 0} col.SQLType = SQLType{Varchar, 0, 0}
case "timestamp without time zone": case "timestamp without time zone":
col.SQLType = SQLType{DateTime, 0, 0} col.SQLType = SQLType{DateTime, 0, 0}
case "timestamp with time zone": case "timestamp with time zone":
col.SQLType = SQLType{TimeStampz, 0, 0} col.SQLType = SQLType{TimeStampz, 0, 0}
case "double precision": case "double precision":
col.SQLType = SQLType{Double, 0, 0} col.SQLType = SQLType{Double, 0, 0}
case "boolean": case "boolean":
col.SQLType = SQLType{Bool, 0, 0} col.SQLType = SQLType{Bool, 0, 0}
case "time without time zone": case "time without time zone":
col.SQLType = SQLType{Time, 0, 0} col.SQLType = SQLType{Time, 0, 0}
default: default:
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
} }
if _, ok := sqlTypes[col.SQLType.Name]; !ok { if _, ok := sqlTypes[col.SQLType.Name]; !ok {
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
} }
case "character_maximum_length": case "character_maximum_length":
i, err := strconv.Atoi(string(content)) i, err := strconv.Atoi(string(content))
if err != nil { if err != nil {
return nil, nil, errors.New("retrieve length error") return nil, nil, errors.New("retrieve length error")
} }
col.Length = i col.Length = i
case "numeric_precision": case "numeric_precision":
case "numeric_precision_radix": case "numeric_precision_radix":
} }
} }
if col.SQLType.IsText() { if col.SQLType.IsText() {
if col.Default != "" { if col.Default != "" {
col.Default = "'" + col.Default + "'" col.Default = "'" + col.Default + "'"
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *postgres) GetTables() ([]*Table, error) { func (db *postgres) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT tablename FROM pg_tables where schemaname = 'public'" s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "tablename": case "tablename":
table.Name = string(content) table.Name = string(content)
} }
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var indexType int var indexType int
var indexName string var indexName string
var colNames []string var colNames []string
for name, content := range record { for name, content := range record {
switch name { switch name {
case "indexname": case "indexname":
indexName = strings.Trim(string(content), `" `) indexName = strings.Trim(string(content), `" `)
case "indexdef": case "indexdef":
c := string(content) c := string(content)
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
indexType = UniqueType indexType = UniqueType
} else { } else {
indexType = IndexType indexType = IndexType
} }
cs := strings.Split(c, "(") cs := strings.Split(c, "(")
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
} }
} }
if strings.HasSuffix(indexName, "_pkey") { if strings.HasSuffix(indexName, "_pkey") {
continue continue
} }
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
newIdxName := indexName[5+len(tableName) : len(indexName)] newIdxName := indexName[5+len(tableName) : len(indexName)]
if newIdxName != "" { if newIdxName != "" {
indexName = newIdxName indexName = newIdxName
} }
} }
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
for _, colName := range colNames { for _, colName := range colNames {
index.Cols = append(index.Cols, strings.Trim(colName, `" `)) index.Cols = append(index.Cols, strings.Trim(colName, `" `))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }

View File

@ -1,223 +1,229 @@
package xorm package xorm
import ( import (
"database/sql" "database/sql"
"strings" "strings"
) )
type sqlite3 struct { type sqlite3 struct {
base base
}
type sqlite3Parser struct {
}
func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) {
return &uri{dbType: SQLITE, dbName: dataSourceName}, nil
} }
func (db *sqlite3) Init(drivername, dataSourceName string) error { func (db *sqlite3) Init(drivername, dataSourceName string) error {
db.base.init(drivername, dataSourceName) return db.base.init(&sqlite3Parser{}, drivername, dataSourceName)
return nil
} }
func (db *sqlite3) SqlType(c *Column) string { func (db *sqlite3) SqlType(c *Column) string {
switch t := c.SQLType.Name; t { switch t := c.SQLType.Name; t {
case Date, DateTime, TimeStamp, Time: case Date, DateTime, TimeStamp, Time:
return Numeric return Numeric
case TimeStampz: case TimeStampz:
return Text return Text
case Char, Varchar, TinyText, Text, MediumText, LongText: case Char, Varchar, TinyText, Text, MediumText, LongText:
return Text return Text
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
return Integer return Integer
case Float, Double, Real: case Float, Double, Real:
return Real return Real
case Decimal, Numeric: case Decimal, Numeric:
return Numeric return Numeric
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
return Blob return Blob
case Serial, BigSerial: case Serial, BigSerial:
c.IsPrimaryKey = true c.IsPrimaryKey = true
c.IsAutoIncrement = true c.IsAutoIncrement = true
c.Nullable = false c.Nullable = false
return Integer return Integer
default: default:
return t return t
} }
} }
func (db *sqlite3) SupportInsertMany() bool { func (db *sqlite3) SupportInsertMany() bool {
return true return true
} }
func (db *sqlite3) QuoteStr() string { func (db *sqlite3) QuoteStr() string {
return "`" return "`"
} }
func (db *sqlite3) AutoIncrStr() string { func (db *sqlite3) AutoIncrStr() string {
return "AUTOINCREMENT" return "AUTOINCREMENT"
} }
func (db *sqlite3) SupportEngine() bool { func (db *sqlite3) SupportEngine() bool {
return false return false
} }
func (db *sqlite3) SupportCharset() bool { func (db *sqlite3) SupportCharset() bool {
return false return false
} }
func (db *sqlite3) IndexOnTable() bool { func (db *sqlite3) IndexOnTable() bool {
return false return false
} }
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
args := []interface{}{idxName} args := []interface{}{idxName}
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
} }
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
} }
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
args := []interface{}{tableName} args := []interface{}{tableName}
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
return sql, args return sql, args
} }
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var sql string var sql string
for _, record := range res { for _, record := range res {
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colCreates := strings.Split(sql[nStart+1:nEnd], ",") colCreates := strings.Split(sql[nStart+1:nEnd], ",")
cols := make(map[string]*Column) cols := make(map[string]*Column)
colSeq := make([]string, 0) colSeq := make([]string, 0)
for _, colStr := range colCreates { for _, colStr := range colCreates {
fields := strings.Fields(strings.TrimSpace(colStr)) fields := strings.Fields(strings.TrimSpace(colStr))
col := new(Column) col := new(Column)
col.Indexes = make(map[string]bool) col.Indexes = make(map[string]bool)
col.Nullable = true col.Nullable = true
for idx, field := range fields { for idx, field := range fields {
if idx == 0 { if idx == 0 {
col.Name = strings.Trim(field, "`[] ") col.Name = strings.Trim(field, "`[] ")
continue continue
} else if idx == 1 { } else if idx == 1 {
col.SQLType = SQLType{field, 0, 0} col.SQLType = SQLType{field, 0, 0}
} }
switch field { switch field {
case "PRIMARY": case "PRIMARY":
col.IsPrimaryKey = true col.IsPrimaryKey = true
case "AUTOINCREMENT": case "AUTOINCREMENT":
col.IsAutoIncrement = true col.IsAutoIncrement = true
case "NULL": case "NULL":
if fields[idx-1] == "NOT" { if fields[idx-1] == "NOT" {
col.Nullable = false col.Nullable = false
} else { } else {
col.Nullable = true col.Nullable = true
} }
} }
} }
cols[col.Name] = col cols[col.Name] = col
colSeq = append(colSeq, col.Name) colSeq = append(colSeq, col.Name)
} }
return colSeq, cols, nil return colSeq, cols, nil
} }
func (db *sqlite3) GetTables() ([]*Table, error) { func (db *sqlite3) GetTables() ([]*Table, error) {
args := []interface{}{} args := []interface{}{}
s := "SELECT name FROM sqlite_master WHERE type='table'" s := "SELECT name FROM sqlite_master WHERE type='table'"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tables := make([]*Table, 0) tables := make([]*Table, 0)
for _, record := range res { for _, record := range res {
table := new(Table) table := new(Table)
for name, content := range record { for name, content := range record {
switch name { switch name {
case "name": case "name":
table.Name = string(content) table.Name = string(content)
} }
} }
if table.Name == "sqlite_sequence" { if table.Name == "sqlite_sequence" {
continue continue
} }
tables = append(tables, table) tables = append(tables, table)
} }
return tables, nil return tables, nil
} }
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{tableName} args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName) cnn, err := sql.Open(db.driverName, db.dataSourceName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer cnn.Close() defer cnn.Close()
res, err := query(cnn, s, args...) res, err := query(cnn, s, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
indexes := make(map[string]*Index, 0) indexes := make(map[string]*Index, 0)
for _, record := range res { for _, record := range res {
var sql string var sql string
index := new(Index) index := new(Index)
for name, content := range record { for name, content := range record {
if name == "sql" { if name == "sql" {
sql = string(content) sql = string(content)
} }
} }
nNStart := strings.Index(sql, "INDEX") nNStart := strings.Index(sql, "INDEX")
nNEnd := strings.Index(sql, "ON") nNEnd := strings.Index(sql, "ON")
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
//fmt.Println(indexName) //fmt.Println(indexName)
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
index.Name = indexName[5+len(tableName) : len(indexName)] index.Name = indexName[5+len(tableName) : len(indexName)]
} else { } else {
index.Name = indexName index.Name = indexName
} }
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
index.Type = UniqueType index.Type = UniqueType
} else { } else {
index.Type = IndexType index.Type = IndexType
} }
nStart := strings.Index(sql, "(") nStart := strings.Index(sql, "(")
nEnd := strings.Index(sql, ")") nEnd := strings.Index(sql, ")")
colIndexes := strings.Split(sql[nStart+1:nEnd], ",") colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
index.Cols = make([]string, 0) index.Cols = make([]string, 0)
for _, col := range colIndexes { for _, col := range colIndexes {
index.Cols = append(index.Cols, strings.Trim(col, "` []")) index.Cols = append(index.Cols, strings.Trim(col, "` []"))
} }
indexes[index.Name] = index indexes[index.Name] = index
} }
return indexes, nil return indexes, nil
} }