refactoring db interface
This commit is contained in:
parent
99c7031b50
commit
4c7e98d8d6
113
mymysql.go
113
mymysql.go
|
@ -1,66 +1,67 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mymysql struct {
|
type mymysql struct {
|
||||||
mysql
|
mysql
|
||||||
proto string
|
}
|
||||||
raddr string
|
|
||||||
laddr string
|
type mymysqlParser struct {
|
||||||
timeout time.Duration
|
}
|
||||||
db string
|
|
||||||
user string
|
func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
passwd string
|
db := &uri{dbType: MYSQL}
|
||||||
|
|
||||||
|
pd := strings.SplitN(dataSourceName, "*", 2)
|
||||||
|
if len(pd) == 2 {
|
||||||
|
// Parse protocol part of URI
|
||||||
|
p := strings.SplitN(pd[0], ":", 2)
|
||||||
|
if len(p) != 2 {
|
||||||
|
return nil, errors.New("Wrong protocol part of URI")
|
||||||
|
}
|
||||||
|
db.proto = p[0]
|
||||||
|
options := strings.Split(p[1], ",")
|
||||||
|
db.raddr = options[0]
|
||||||
|
for _, o := range options[1:] {
|
||||||
|
kv := strings.SplitN(o, "=", 2)
|
||||||
|
var k, v string
|
||||||
|
if len(kv) == 2 {
|
||||||
|
k, v = kv[0], kv[1]
|
||||||
|
} else {
|
||||||
|
k, v = o, "true"
|
||||||
|
}
|
||||||
|
switch k {
|
||||||
|
case "laddr":
|
||||||
|
db.laddr = v
|
||||||
|
case "timeout":
|
||||||
|
to, err := time.ParseDuration(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
db.timeout = to
|
||||||
|
default:
|
||||||
|
return nil, errors.New("Unknown option: " + k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove protocol part
|
||||||
|
pd = pd[1:]
|
||||||
|
}
|
||||||
|
// Parse database part of URI
|
||||||
|
dup := strings.SplitN(pd[0], "/", 3)
|
||||||
|
if len(dup) != 3 {
|
||||||
|
return nil, errors.New("Wrong database part of URI")
|
||||||
|
}
|
||||||
|
db.dbName = dup[0]
|
||||||
|
db.user = dup[1]
|
||||||
|
db.passwd = dup[2]
|
||||||
|
|
||||||
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mymysql) Init(drivername, uri string) error {
|
func (db *mymysql) Init(drivername, uri string) error {
|
||||||
db.mysql.base.init(drivername, uri)
|
return db.mysql.base.init(&mymysqlParser{}, drivername, uri)
|
||||||
pd := strings.SplitN(uri, "*", 2)
|
|
||||||
if len(pd) == 2 {
|
|
||||||
// Parse protocol part of URI
|
|
||||||
p := strings.SplitN(pd[0], ":", 2)
|
|
||||||
if len(p) != 2 {
|
|
||||||
return errors.New("Wrong protocol part of URI")
|
|
||||||
}
|
|
||||||
db.proto = p[0]
|
|
||||||
options := strings.Split(p[1], ",")
|
|
||||||
db.raddr = options[0]
|
|
||||||
for _, o := range options[1:] {
|
|
||||||
kv := strings.SplitN(o, "=", 2)
|
|
||||||
var k, v string
|
|
||||||
if len(kv) == 2 {
|
|
||||||
k, v = kv[0], kv[1]
|
|
||||||
} else {
|
|
||||||
k, v = o, "true"
|
|
||||||
}
|
|
||||||
switch k {
|
|
||||||
case "laddr":
|
|
||||||
db.laddr = v
|
|
||||||
case "timeout":
|
|
||||||
to, err := time.ParseDuration(v)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
db.timeout = to
|
|
||||||
default:
|
|
||||||
return errors.New("Unknown option: " + k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Remove protocol part
|
|
||||||
pd = pd[1:]
|
|
||||||
}
|
|
||||||
// Parse database part of URI
|
|
||||||
dup := strings.SplitN(pd[0], "/", 3)
|
|
||||||
if len(dup) != 3 {
|
|
||||||
return errors.New("Wrong database part of URI")
|
|
||||||
}
|
|
||||||
db.dbname = dup[0]
|
|
||||||
db.user = dup[1]
|
|
||||||
db.passwd = dup[2]
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
520
mysql.go
520
mysql.go
|
@ -1,311 +1,323 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type base struct {
|
type uri struct {
|
||||||
drivername string
|
dbType string
|
||||||
dataSourceName string
|
proto string
|
||||||
|
host string
|
||||||
|
port string
|
||||||
|
dbName string
|
||||||
|
user string
|
||||||
|
passwd string
|
||||||
|
charset string
|
||||||
|
laddr string
|
||||||
|
raddr string
|
||||||
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *base) init(drivername, dataSourceName string) {
|
type parser interface {
|
||||||
b.drivername, b.dataSourceName = drivername, dataSourceName
|
parse(driverName, dataSourceName string) (*uri, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mysqlParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
//cfg.params = make(map[string]string)
|
||||||
|
dsnPattern := regexp.MustCompile(
|
||||||
|
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
||||||
|
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
||||||
|
`\/(?P<dbname>.*?)` + // /dbname
|
||||||
|
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
||||||
|
matches := dsnPattern.FindStringSubmatch(dataSourceName)
|
||||||
|
//tlsConfigRegister := make(map[string]*tls.Config)
|
||||||
|
names := dsnPattern.SubexpNames()
|
||||||
|
|
||||||
|
uri := &uri{dbType: MYSQL}
|
||||||
|
|
||||||
|
for i, match := range matches {
|
||||||
|
switch names[i] {
|
||||||
|
case "dbname":
|
||||||
|
uri.dbName = match
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return uri, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type base struct {
|
||||||
|
parser parser
|
||||||
|
driverName string
|
||||||
|
dataSourceName string
|
||||||
|
*uri
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *base) init(parser parser, drivername, dataSourceName string) (err error) {
|
||||||
|
b.parser = parser
|
||||||
|
b.driverName, b.dataSourceName = drivername, dataSourceName
|
||||||
|
b.uri, err = b.parser.parse(b.driverName, b.dataSourceName)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type mysql struct {
|
type mysql struct {
|
||||||
base
|
base
|
||||||
user string
|
net string
|
||||||
passwd string
|
addr string
|
||||||
net string
|
params map[string]string
|
||||||
addr string
|
loc *time.Location
|
||||||
dbname string
|
timeout time.Duration
|
||||||
params map[string]string
|
tls *tls.Config
|
||||||
loc *time.Location
|
allowAllFiles bool
|
||||||
timeout time.Duration
|
allowOldPasswords bool
|
||||||
tls *tls.Config
|
clientFoundRows bool
|
||||||
allowAllFiles bool
|
|
||||||
allowOldPasswords bool
|
|
||||||
clientFoundRows bool
|
|
||||||
}
|
|
||||||
|
|
||||||
/*func readBool(input string) (value bool, valid bool) {
|
|
||||||
switch input {
|
|
||||||
case "1", "true", "TRUE", "True":
|
|
||||||
return true, true
|
|
||||||
case "0", "false", "FALSE", "False":
|
|
||||||
return false, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not a valid bool value
|
|
||||||
return
|
|
||||||
}*/
|
|
||||||
|
|
||||||
func (cfg *mysql) parseDSN(dsn string) (err error) {
|
|
||||||
//cfg.params = make(map[string]string)
|
|
||||||
dsnPattern := regexp.MustCompile(
|
|
||||||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
|
||||||
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
|
||||||
`\/(?P<dbname>.*?)` + // /dbname
|
|
||||||
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
|
||||||
matches := dsnPattern.FindStringSubmatch(dsn)
|
|
||||||
//tlsConfigRegister := make(map[string]*tls.Config)
|
|
||||||
names := dsnPattern.SubexpNames()
|
|
||||||
|
|
||||||
for i, match := range matches {
|
|
||||||
switch names[i] {
|
|
||||||
case "dbname":
|
|
||||||
cfg.dbname = match
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) Init(drivername, uri string) error {
|
func (db *mysql) Init(drivername, uri string) error {
|
||||||
db.base.init(drivername, uri)
|
return db.base.init(&mysqlParser{}, drivername, uri)
|
||||||
return db.parseDSN(uri)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SqlType(c *Column) string {
|
func (db *mysql) SqlType(c *Column) string {
|
||||||
var res string
|
var res string
|
||||||
switch t := c.SQLType.Name; t {
|
switch t := c.SQLType.Name; t {
|
||||||
case Bool:
|
case Bool:
|
||||||
res = TinyInt
|
res = TinyInt
|
||||||
case Serial:
|
case Serial:
|
||||||
c.IsAutoIncrement = true
|
c.IsAutoIncrement = true
|
||||||
c.IsPrimaryKey = true
|
c.IsPrimaryKey = true
|
||||||
c.Nullable = false
|
c.Nullable = false
|
||||||
res = Int
|
res = Int
|
||||||
case BigSerial:
|
case BigSerial:
|
||||||
c.IsAutoIncrement = true
|
c.IsAutoIncrement = true
|
||||||
c.IsPrimaryKey = true
|
c.IsPrimaryKey = true
|
||||||
c.Nullable = false
|
c.Nullable = false
|
||||||
res = BigInt
|
res = BigInt
|
||||||
case Bytea:
|
case Bytea:
|
||||||
res = Blob
|
res = Blob
|
||||||
case TimeStampz:
|
case TimeStampz:
|
||||||
res = Char
|
res = Char
|
||||||
c.Length = 64
|
c.Length = 64
|
||||||
default:
|
default:
|
||||||
res = t
|
res = t
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasLen1 bool = (c.Length > 0)
|
var hasLen1 bool = (c.Length > 0)
|
||||||
var hasLen2 bool = (c.Length2 > 0)
|
var hasLen2 bool = (c.Length2 > 0)
|
||||||
if hasLen1 {
|
if hasLen1 {
|
||||||
res += "(" + strconv.Itoa(c.Length) + ")"
|
res += "(" + strconv.Itoa(c.Length) + ")"
|
||||||
} else if hasLen2 {
|
} else if hasLen2 {
|
||||||
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
|
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SupportInsertMany() bool {
|
func (db *mysql) SupportInsertMany() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) QuoteStr() string {
|
func (db *mysql) QuoteStr() string {
|
||||||
return "`"
|
return "`"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SupportEngine() bool {
|
func (db *mysql) SupportEngine() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) AutoIncrStr() string {
|
func (db *mysql) AutoIncrStr() string {
|
||||||
return "AUTO_INCREMENT"
|
return "AUTO_INCREMENT"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) SupportCharset() bool {
|
func (db *mysql) SupportCharset() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) IndexOnTable() bool {
|
func (db *mysql) IndexOnTable() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
||||||
args := []interface{}{db.dbname, tableName, idxName}
|
args := []interface{}{db.dbName, tableName, idxName}
|
||||||
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
|
sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`"
|
||||||
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
|
sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
||||||
args := []interface{}{db.dbname, tableName, colName}
|
args := []interface{}{db.dbName, tableName, colName}
|
||||||
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
|
func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
|
||||||
args := []interface{}{db.dbname, tableName}
|
args := []interface{}{db.dbName, tableName}
|
||||||
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
||||||
args := []interface{}{db.dbname, tableName}
|
args := []interface{}{db.dbName, tableName}
|
||||||
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
|
||||||
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
cols := make(map[string]*Column)
|
cols := make(map[string]*Column)
|
||||||
colSeq := make([]string, 0)
|
colSeq := make([]string, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
col := new(Column)
|
col := new(Column)
|
||||||
col.Indexes = make(map[string]bool)
|
col.Indexes = make(map[string]bool)
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "COLUMN_NAME":
|
case "COLUMN_NAME":
|
||||||
col.Name = strings.Trim(string(content), "` ")
|
col.Name = strings.Trim(string(content), "` ")
|
||||||
case "IS_NULLABLE":
|
case "IS_NULLABLE":
|
||||||
if "YES" == string(content) {
|
if "YES" == string(content) {
|
||||||
col.Nullable = true
|
col.Nullable = true
|
||||||
}
|
}
|
||||||
case "COLUMN_DEFAULT":
|
case "COLUMN_DEFAULT":
|
||||||
// add ''
|
// add ''
|
||||||
col.Default = string(content)
|
col.Default = string(content)
|
||||||
case "COLUMN_TYPE":
|
case "COLUMN_TYPE":
|
||||||
cts := strings.Split(string(content), "(")
|
cts := strings.Split(string(content), "(")
|
||||||
var len1, len2 int
|
var len1, len2 int
|
||||||
if len(cts) == 2 {
|
if len(cts) == 2 {
|
||||||
idx := strings.Index(cts[1], ")")
|
idx := strings.Index(cts[1], ")")
|
||||||
lens := strings.Split(cts[1][0:idx], ",")
|
lens := strings.Split(cts[1][0:idx], ",")
|
||||||
len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
|
len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if len(lens) == 2 {
|
if len(lens) == 2 {
|
||||||
len2, err = strconv.Atoi(lens[1])
|
len2, err = strconv.Atoi(lens[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
colName := cts[0]
|
colName := cts[0]
|
||||||
colType := strings.ToUpper(colName)
|
colType := strings.ToUpper(colName)
|
||||||
col.Length = len1
|
col.Length = len1
|
||||||
col.Length2 = len2
|
col.Length2 = len2
|
||||||
if _, ok := sqlTypes[colType]; ok {
|
if _, ok := sqlTypes[colType]; ok {
|
||||||
col.SQLType = SQLType{colType, len1, len2}
|
col.SQLType = SQLType{colType, len1, len2}
|
||||||
} else {
|
} else {
|
||||||
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
|
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
|
||||||
}
|
}
|
||||||
case "COLUMN_KEY":
|
case "COLUMN_KEY":
|
||||||
key := string(content)
|
key := string(content)
|
||||||
if key == "PRI" {
|
if key == "PRI" {
|
||||||
col.IsPrimaryKey = true
|
col.IsPrimaryKey = true
|
||||||
}
|
}
|
||||||
if key == "UNI" {
|
if key == "UNI" {
|
||||||
//col.is
|
//col.is
|
||||||
}
|
}
|
||||||
case "EXTRA":
|
case "EXTRA":
|
||||||
extra := string(content)
|
extra := string(content)
|
||||||
if extra == "auto_increment" {
|
if extra == "auto_increment" {
|
||||||
col.IsAutoIncrement = true
|
col.IsAutoIncrement = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if col.SQLType.IsText() {
|
if col.SQLType.IsText() {
|
||||||
if col.Default != "" {
|
if col.Default != "" {
|
||||||
col.Default = "'" + col.Default + "'"
|
col.Default = "'" + col.Default + "'"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cols[col.Name] = col
|
cols[col.Name] = col
|
||||||
colSeq = append(colSeq, col.Name)
|
colSeq = append(colSeq, col.Name)
|
||||||
}
|
}
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetTables() ([]*Table, error) {
|
func (db *mysql) GetTables() ([]*Table, error) {
|
||||||
args := []interface{}{db.dbname}
|
args := []interface{}{db.dbName}
|
||||||
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
|
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tables := make([]*Table, 0)
|
tables := make([]*Table, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
table := new(Table)
|
table := new(Table)
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "TABLE_NAME":
|
case "TABLE_NAME":
|
||||||
table.Name = strings.Trim(string(content), "` ")
|
table.Name = strings.Trim(string(content), "` ")
|
||||||
case "ENGINE":
|
case "ENGINE":
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tables = append(tables, table)
|
tables = append(tables, table)
|
||||||
}
|
}
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
|
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
|
||||||
args := []interface{}{db.dbname, tableName}
|
args := []interface{}{db.dbName, tableName}
|
||||||
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
indexes := make(map[string]*Index, 0)
|
indexes := make(map[string]*Index, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
var indexType int
|
var indexType int
|
||||||
var indexName, colName string
|
var indexName, colName string
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "NON_UNIQUE":
|
case "NON_UNIQUE":
|
||||||
if "YES" == string(content) || string(content) == "1" {
|
if "YES" == string(content) || string(content) == "1" {
|
||||||
indexType = IndexType
|
indexType = IndexType
|
||||||
} else {
|
} else {
|
||||||
indexType = UniqueType
|
indexType = UniqueType
|
||||||
}
|
}
|
||||||
case "INDEX_NAME":
|
case "INDEX_NAME":
|
||||||
indexName = string(content)
|
indexName = string(content)
|
||||||
case "COLUMN_NAME":
|
case "COLUMN_NAME":
|
||||||
colName = strings.Trim(string(content), "` ")
|
colName = strings.Trim(string(content), "` ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if indexName == "PRIMARY" {
|
if indexName == "PRIMARY" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
|
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
|
||||||
indexName = indexName[5+len(tableName) : len(indexName)]
|
indexName = indexName[5+len(tableName) : len(indexName)]
|
||||||
}
|
}
|
||||||
|
|
||||||
var index *Index
|
var index *Index
|
||||||
var ok bool
|
var ok bool
|
||||||
if index, ok = indexes[indexName]; !ok {
|
if index, ok = indexes[indexName]; !ok {
|
||||||
index = new(Index)
|
index = new(Index)
|
||||||
index.Type = indexType
|
index.Type = indexType
|
||||||
index.Name = indexName
|
index.Name = indexName
|
||||||
indexes[indexName] = index
|
indexes[indexName] = index
|
||||||
}
|
}
|
||||||
index.AddColumn(colName)
|
index.AddColumn(colName)
|
||||||
}
|
}
|
||||||
return indexes, nil
|
return indexes, nil
|
||||||
}
|
}
|
||||||
|
|
459
postgres.go
459
postgres.go
|
@ -1,300 +1,305 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
base
|
base
|
||||||
dbname string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type values map[string]string
|
type values map[string]string
|
||||||
|
|
||||||
func (vs values) Set(k, v string) {
|
func (vs values) Set(k, v string) {
|
||||||
vs[k] = v
|
vs[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vs values) Get(k string) (v string) {
|
func (vs values) Get(k string) (v string) {
|
||||||
return vs[k]
|
return vs[k]
|
||||||
}
|
}
|
||||||
|
|
||||||
func errorf(s string, args ...interface{}) {
|
func errorf(s string, args ...interface{}) {
|
||||||
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
|
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseOpts(name string, o values) {
|
func parseOpts(name string, o values) {
|
||||||
if len(name) == 0 {
|
if len(name) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
name = strings.TrimSpace(name)
|
name = strings.TrimSpace(name)
|
||||||
|
|
||||||
ps := strings.Split(name, " ")
|
ps := strings.Split(name, " ")
|
||||||
for _, p := range ps {
|
for _, p := range ps {
|
||||||
kv := strings.Split(p, "=")
|
kv := strings.Split(p, "=")
|
||||||
if len(kv) < 2 {
|
if len(kv) < 2 {
|
||||||
errorf("invalid option: %q", p)
|
errorf("invalid option: %q", p)
|
||||||
}
|
}
|
||||||
o.Set(kv[0], kv[1])
|
o.Set(kv[0], kv[1])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type postgresParser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
db := &uri{dbType: POSTGRES}
|
||||||
|
o := make(values)
|
||||||
|
parseOpts(dataSourceName, o)
|
||||||
|
|
||||||
|
db.dbName = o.Get("dbname")
|
||||||
|
if db.dbName == "" {
|
||||||
|
return nil, errors.New("dbname is empty")
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) Init(drivername, uri string) error {
|
func (db *postgres) Init(drivername, uri string) error {
|
||||||
db.base.init(drivername, uri)
|
return db.base.init(&postgresParser{}, drivername, uri)
|
||||||
|
|
||||||
o := make(values)
|
|
||||||
parseOpts(uri, o)
|
|
||||||
|
|
||||||
db.dbname = o.Get("dbname")
|
|
||||||
if db.dbname == "" {
|
|
||||||
return errors.New("dbname is empty")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) SqlType(c *Column) string {
|
func (db *postgres) SqlType(c *Column) string {
|
||||||
var res string
|
var res string
|
||||||
switch t := c.SQLType.Name; t {
|
switch t := c.SQLType.Name; t {
|
||||||
case TinyInt:
|
case TinyInt:
|
||||||
res = SmallInt
|
res = SmallInt
|
||||||
case MediumInt, Int, Integer:
|
case MediumInt, Int, Integer:
|
||||||
return Integer
|
return Integer
|
||||||
case Serial, BigSerial:
|
case Serial, BigSerial:
|
||||||
c.IsAutoIncrement = true
|
c.IsAutoIncrement = true
|
||||||
c.Nullable = false
|
c.Nullable = false
|
||||||
res = t
|
res = t
|
||||||
case Binary, VarBinary:
|
case Binary, VarBinary:
|
||||||
return Bytea
|
return Bytea
|
||||||
case DateTime:
|
case DateTime:
|
||||||
res = TimeStamp
|
res = TimeStamp
|
||||||
case TimeStampz:
|
case TimeStampz:
|
||||||
return "timestamp with time zone"
|
return "timestamp with time zone"
|
||||||
case Float:
|
case Float:
|
||||||
res = Real
|
res = Real
|
||||||
case TinyText, MediumText, LongText:
|
case TinyText, MediumText, LongText:
|
||||||
res = Text
|
res = Text
|
||||||
case Blob, TinyBlob, MediumBlob, LongBlob:
|
case Blob, TinyBlob, MediumBlob, LongBlob:
|
||||||
return Bytea
|
return Bytea
|
||||||
case Double:
|
case Double:
|
||||||
return "DOUBLE PRECISION"
|
return "DOUBLE PRECISION"
|
||||||
default:
|
default:
|
||||||
if c.IsAutoIncrement {
|
if c.IsAutoIncrement {
|
||||||
return Serial
|
return Serial
|
||||||
}
|
}
|
||||||
res = t
|
res = t
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasLen1 bool = (c.Length > 0)
|
var hasLen1 bool = (c.Length > 0)
|
||||||
var hasLen2 bool = (c.Length2 > 0)
|
var hasLen2 bool = (c.Length2 > 0)
|
||||||
if hasLen1 {
|
if hasLen1 {
|
||||||
res += "(" + strconv.Itoa(c.Length) + ")"
|
res += "(" + strconv.Itoa(c.Length) + ")"
|
||||||
} else if hasLen2 {
|
} else if hasLen2 {
|
||||||
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
|
res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")"
|
||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) SupportInsertMany() bool {
|
func (db *postgres) SupportInsertMany() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) QuoteStr() string {
|
func (db *postgres) QuoteStr() string {
|
||||||
return "\""
|
return "\""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) AutoIncrStr() string {
|
func (db *postgres) AutoIncrStr() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) SupportEngine() bool {
|
func (db *postgres) SupportEngine() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) SupportCharset() bool {
|
func (db *postgres) SupportCharset() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) IndexOnTable() bool {
|
func (db *postgres) IndexOnTable() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
||||||
args := []interface{}{tableName, idxName}
|
args := []interface{}{tableName, idxName}
|
||||||
return `SELECT indexname FROM pg_indexes ` +
|
return `SELECT indexname FROM pg_indexes ` +
|
||||||
`WHERE tablename = ? AND indexname = ?`, args
|
`WHERE tablename = ? AND indexname = ?`, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
|
func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
|
return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
||||||
args := []interface{}{tableName, colName}
|
args := []interface{}{tableName, colName}
|
||||||
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
|
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
|
||||||
" AND column_name = ?", args
|
" AND column_name = ?", args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
|
s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" +
|
||||||
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
|
", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1"
|
||||||
|
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
cols := make(map[string]*Column)
|
cols := make(map[string]*Column)
|
||||||
colSeq := make([]string, 0)
|
colSeq := make([]string, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
col := new(Column)
|
col := new(Column)
|
||||||
col.Indexes = make(map[string]bool)
|
col.Indexes = make(map[string]bool)
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "column_name":
|
case "column_name":
|
||||||
col.Name = strings.Trim(string(content), `" `)
|
col.Name = strings.Trim(string(content), `" `)
|
||||||
case "column_default":
|
case "column_default":
|
||||||
if strings.HasPrefix(string(content), "nextval") {
|
if strings.HasPrefix(string(content), "nextval") {
|
||||||
col.IsPrimaryKey = true
|
col.IsPrimaryKey = true
|
||||||
} else {
|
} else {
|
||||||
col.Default = string(content)
|
col.Default = string(content)
|
||||||
}
|
}
|
||||||
case "is_nullable":
|
case "is_nullable":
|
||||||
if string(content) == "YES" {
|
if string(content) == "YES" {
|
||||||
col.Nullable = true
|
col.Nullable = true
|
||||||
} else {
|
} else {
|
||||||
col.Nullable = false
|
col.Nullable = false
|
||||||
}
|
}
|
||||||
case "data_type":
|
case "data_type":
|
||||||
ct := string(content)
|
ct := string(content)
|
||||||
switch ct {
|
switch ct {
|
||||||
case "character varying", "character":
|
case "character varying", "character":
|
||||||
col.SQLType = SQLType{Varchar, 0, 0}
|
col.SQLType = SQLType{Varchar, 0, 0}
|
||||||
case "timestamp without time zone":
|
case "timestamp without time zone":
|
||||||
col.SQLType = SQLType{DateTime, 0, 0}
|
col.SQLType = SQLType{DateTime, 0, 0}
|
||||||
case "timestamp with time zone":
|
case "timestamp with time zone":
|
||||||
col.SQLType = SQLType{TimeStampz, 0, 0}
|
col.SQLType = SQLType{TimeStampz, 0, 0}
|
||||||
case "double precision":
|
case "double precision":
|
||||||
col.SQLType = SQLType{Double, 0, 0}
|
col.SQLType = SQLType{Double, 0, 0}
|
||||||
case "boolean":
|
case "boolean":
|
||||||
col.SQLType = SQLType{Bool, 0, 0}
|
col.SQLType = SQLType{Bool, 0, 0}
|
||||||
case "time without time zone":
|
case "time without time zone":
|
||||||
col.SQLType = SQLType{Time, 0, 0}
|
col.SQLType = SQLType{Time, 0, 0}
|
||||||
default:
|
default:
|
||||||
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
|
col.SQLType = SQLType{strings.ToUpper(ct), 0, 0}
|
||||||
}
|
}
|
||||||
if _, ok := sqlTypes[col.SQLType.Name]; !ok {
|
if _, ok := sqlTypes[col.SQLType.Name]; !ok {
|
||||||
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
|
return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct))
|
||||||
}
|
}
|
||||||
case "character_maximum_length":
|
case "character_maximum_length":
|
||||||
i, err := strconv.Atoi(string(content))
|
i, err := strconv.Atoi(string(content))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, errors.New("retrieve length error")
|
return nil, nil, errors.New("retrieve length error")
|
||||||
}
|
}
|
||||||
col.Length = i
|
col.Length = i
|
||||||
case "numeric_precision":
|
case "numeric_precision":
|
||||||
case "numeric_precision_radix":
|
case "numeric_precision_radix":
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if col.SQLType.IsText() {
|
if col.SQLType.IsText() {
|
||||||
if col.Default != "" {
|
if col.Default != "" {
|
||||||
col.Default = "'" + col.Default + "'"
|
col.Default = "'" + col.Default + "'"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cols[col.Name] = col
|
cols[col.Name] = col
|
||||||
colSeq = append(colSeq, col.Name)
|
colSeq = append(colSeq, col.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetTables() ([]*Table, error) {
|
func (db *postgres) GetTables() ([]*Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
|
s := "SELECT tablename FROM pg_tables where schemaname = 'public'"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tables := make([]*Table, 0)
|
tables := make([]*Table, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
table := new(Table)
|
table := new(Table)
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "tablename":
|
case "tablename":
|
||||||
table.Name = string(content)
|
table.Name = string(content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tables = append(tables, table)
|
tables = append(tables, table)
|
||||||
}
|
}
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
|
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
|
s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1"
|
||||||
|
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
indexes := make(map[string]*Index, 0)
|
indexes := make(map[string]*Index, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
var indexType int
|
var indexType int
|
||||||
var indexName string
|
var indexName string
|
||||||
var colNames []string
|
var colNames []string
|
||||||
|
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "indexname":
|
case "indexname":
|
||||||
indexName = strings.Trim(string(content), `" `)
|
indexName = strings.Trim(string(content), `" `)
|
||||||
case "indexdef":
|
case "indexdef":
|
||||||
c := string(content)
|
c := string(content)
|
||||||
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
|
if strings.HasPrefix(c, "CREATE UNIQUE INDEX") {
|
||||||
indexType = UniqueType
|
indexType = UniqueType
|
||||||
} else {
|
} else {
|
||||||
indexType = IndexType
|
indexType = IndexType
|
||||||
}
|
}
|
||||||
cs := strings.Split(c, "(")
|
cs := strings.Split(c, "(")
|
||||||
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
|
colNames = strings.Split(cs[1][0:len(cs[1])-1], ",")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(indexName, "_pkey") {
|
if strings.HasSuffix(indexName, "_pkey") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
|
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
|
||||||
newIdxName := indexName[5+len(tableName) : len(indexName)]
|
newIdxName := indexName[5+len(tableName) : len(indexName)]
|
||||||
if newIdxName != "" {
|
if newIdxName != "" {
|
||||||
indexName = newIdxName
|
indexName = newIdxName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
|
index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)}
|
||||||
for _, colName := range colNames {
|
for _, colName := range colNames {
|
||||||
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
|
index.Cols = append(index.Cols, strings.Trim(colName, `" `))
|
||||||
}
|
}
|
||||||
indexes[index.Name] = index
|
indexes[index.Name] = index
|
||||||
}
|
}
|
||||||
return indexes, nil
|
return indexes, nil
|
||||||
}
|
}
|
||||||
|
|
334
sqlite3.go
334
sqlite3.go
|
@ -1,223 +1,229 @@
|
||||||
package xorm
|
package xorm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlite3 struct {
|
type sqlite3 struct {
|
||||||
base
|
base
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlite3Parser struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) {
|
||||||
|
return &uri{dbType: SQLITE, dbName: dataSourceName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) Init(drivername, dataSourceName string) error {
|
func (db *sqlite3) Init(drivername, dataSourceName string) error {
|
||||||
db.base.init(drivername, dataSourceName)
|
return db.base.init(&sqlite3Parser{}, drivername, dataSourceName)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SqlType(c *Column) string {
|
func (db *sqlite3) SqlType(c *Column) string {
|
||||||
switch t := c.SQLType.Name; t {
|
switch t := c.SQLType.Name; t {
|
||||||
case Date, DateTime, TimeStamp, Time:
|
case Date, DateTime, TimeStamp, Time:
|
||||||
return Numeric
|
return Numeric
|
||||||
case TimeStampz:
|
case TimeStampz:
|
||||||
return Text
|
return Text
|
||||||
case Char, Varchar, TinyText, Text, MediumText, LongText:
|
case Char, Varchar, TinyText, Text, MediumText, LongText:
|
||||||
return Text
|
return Text
|
||||||
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
|
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool:
|
||||||
return Integer
|
return Integer
|
||||||
case Float, Double, Real:
|
case Float, Double, Real:
|
||||||
return Real
|
return Real
|
||||||
case Decimal, Numeric:
|
case Decimal, Numeric:
|
||||||
return Numeric
|
return Numeric
|
||||||
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
|
case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary:
|
||||||
return Blob
|
return Blob
|
||||||
case Serial, BigSerial:
|
case Serial, BigSerial:
|
||||||
c.IsPrimaryKey = true
|
c.IsPrimaryKey = true
|
||||||
c.IsAutoIncrement = true
|
c.IsAutoIncrement = true
|
||||||
c.Nullable = false
|
c.Nullable = false
|
||||||
return Integer
|
return Integer
|
||||||
default:
|
default:
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SupportInsertMany() bool {
|
func (db *sqlite3) SupportInsertMany() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) QuoteStr() string {
|
func (db *sqlite3) QuoteStr() string {
|
||||||
return "`"
|
return "`"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) AutoIncrStr() string {
|
func (db *sqlite3) AutoIncrStr() string {
|
||||||
return "AUTOINCREMENT"
|
return "AUTOINCREMENT"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SupportEngine() bool {
|
func (db *sqlite3) SupportEngine() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) SupportCharset() bool {
|
func (db *sqlite3) SupportCharset() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) IndexOnTable() bool {
|
func (db *sqlite3) IndexOnTable() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) {
|
||||||
args := []interface{}{idxName}
|
args := []interface{}{idxName}
|
||||||
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
|
return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
|
func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
|
return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
|
sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))"
|
||||||
return sql, args
|
return sql, args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var sql string
|
var sql string
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
if name == "sql" {
|
if name == "sql" {
|
||||||
sql = string(content)
|
sql = string(content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nStart := strings.Index(sql, "(")
|
nStart := strings.Index(sql, "(")
|
||||||
nEnd := strings.Index(sql, ")")
|
nEnd := strings.Index(sql, ")")
|
||||||
colCreates := strings.Split(sql[nStart+1:nEnd], ",")
|
colCreates := strings.Split(sql[nStart+1:nEnd], ",")
|
||||||
cols := make(map[string]*Column)
|
cols := make(map[string]*Column)
|
||||||
colSeq := make([]string, 0)
|
colSeq := make([]string, 0)
|
||||||
for _, colStr := range colCreates {
|
for _, colStr := range colCreates {
|
||||||
fields := strings.Fields(strings.TrimSpace(colStr))
|
fields := strings.Fields(strings.TrimSpace(colStr))
|
||||||
col := new(Column)
|
col := new(Column)
|
||||||
col.Indexes = make(map[string]bool)
|
col.Indexes = make(map[string]bool)
|
||||||
col.Nullable = true
|
col.Nullable = true
|
||||||
for idx, field := range fields {
|
for idx, field := range fields {
|
||||||
if idx == 0 {
|
if idx == 0 {
|
||||||
col.Name = strings.Trim(field, "`[] ")
|
col.Name = strings.Trim(field, "`[] ")
|
||||||
continue
|
continue
|
||||||
} else if idx == 1 {
|
} else if idx == 1 {
|
||||||
col.SQLType = SQLType{field, 0, 0}
|
col.SQLType = SQLType{field, 0, 0}
|
||||||
}
|
}
|
||||||
switch field {
|
switch field {
|
||||||
case "PRIMARY":
|
case "PRIMARY":
|
||||||
col.IsPrimaryKey = true
|
col.IsPrimaryKey = true
|
||||||
case "AUTOINCREMENT":
|
case "AUTOINCREMENT":
|
||||||
col.IsAutoIncrement = true
|
col.IsAutoIncrement = true
|
||||||
case "NULL":
|
case "NULL":
|
||||||
if fields[idx-1] == "NOT" {
|
if fields[idx-1] == "NOT" {
|
||||||
col.Nullable = false
|
col.Nullable = false
|
||||||
} else {
|
} else {
|
||||||
col.Nullable = true
|
col.Nullable = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cols[col.Name] = col
|
cols[col.Name] = col
|
||||||
colSeq = append(colSeq, col.Name)
|
colSeq = append(colSeq, col.Name)
|
||||||
}
|
}
|
||||||
return colSeq, cols, nil
|
return colSeq, cols, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) GetTables() ([]*Table, error) {
|
func (db *sqlite3) GetTables() ([]*Table, error) {
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
s := "SELECT name FROM sqlite_master WHERE type='table'"
|
||||||
|
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tables := make([]*Table, 0)
|
tables := make([]*Table, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
table := new(Table)
|
table := new(Table)
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
switch name {
|
switch name {
|
||||||
case "name":
|
case "name":
|
||||||
table.Name = string(content)
|
table.Name = string(content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if table.Name == "sqlite_sequence" {
|
if table.Name == "sqlite_sequence" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tables = append(tables, table)
|
tables = append(tables, table)
|
||||||
}
|
}
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
|
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
|
||||||
args := []interface{}{tableName}
|
args := []interface{}{tableName}
|
||||||
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?"
|
||||||
cnn, err := sql.Open(db.drivername, db.dataSourceName)
|
cnn, err := sql.Open(db.driverName, db.dataSourceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer cnn.Close()
|
defer cnn.Close()
|
||||||
res, err := query(cnn, s, args...)
|
res, err := query(cnn, s, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
indexes := make(map[string]*Index, 0)
|
indexes := make(map[string]*Index, 0)
|
||||||
for _, record := range res {
|
for _, record := range res {
|
||||||
var sql string
|
var sql string
|
||||||
index := new(Index)
|
index := new(Index)
|
||||||
for name, content := range record {
|
for name, content := range record {
|
||||||
if name == "sql" {
|
if name == "sql" {
|
||||||
sql = string(content)
|
sql = string(content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nNStart := strings.Index(sql, "INDEX")
|
nNStart := strings.Index(sql, "INDEX")
|
||||||
nNEnd := strings.Index(sql, "ON")
|
nNEnd := strings.Index(sql, "ON")
|
||||||
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
|
indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []")
|
||||||
//fmt.Println(indexName)
|
//fmt.Println(indexName)
|
||||||
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
|
if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) {
|
||||||
index.Name = indexName[5+len(tableName) : len(indexName)]
|
index.Name = indexName[5+len(tableName) : len(indexName)]
|
||||||
} else {
|
} else {
|
||||||
index.Name = indexName
|
index.Name = indexName
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
|
if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") {
|
||||||
index.Type = UniqueType
|
index.Type = UniqueType
|
||||||
} else {
|
} else {
|
||||||
index.Type = IndexType
|
index.Type = IndexType
|
||||||
}
|
}
|
||||||
|
|
||||||
nStart := strings.Index(sql, "(")
|
nStart := strings.Index(sql, "(")
|
||||||
nEnd := strings.Index(sql, ")")
|
nEnd := strings.Index(sql, ")")
|
||||||
colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
|
colIndexes := strings.Split(sql[nStart+1:nEnd], ",")
|
||||||
|
|
||||||
index.Cols = make([]string, 0)
|
index.Cols = make([]string, 0)
|
||||||
for _, col := range colIndexes {
|
for _, col := range colIndexes {
|
||||||
index.Cols = append(index.Cols, strings.Trim(col, "` []"))
|
index.Cols = append(index.Cols, strings.Trim(col, "` []"))
|
||||||
}
|
}
|
||||||
indexes[index.Name] = index
|
indexes[index.Name] = index
|
||||||
}
|
}
|
||||||
|
|
||||||
return indexes, nil
|
return indexes, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue