diff --git a/mymysql.go b/mymysql.go index c664be0d..8101d4c0 100644 --- a/mymysql.go +++ b/mymysql.go @@ -1,66 +1,67 @@ package xorm import ( - "errors" - "strings" - "time" + "errors" + "strings" + "time" ) type mymysql struct { - mysql - proto string - raddr string - laddr string - timeout time.Duration - db string - user string - passwd string + mysql +} + +type mymysqlParser struct { +} + +func (p *mymysqlParser) parse(driverName, dataSourceName string) (*uri, error) { + db := &uri{dbType: MYSQL} + + pd := strings.SplitN(dataSourceName, "*", 2) + if len(pd) == 2 { + // Parse protocol part of URI + p := strings.SplitN(pd[0], ":", 2) + if len(p) != 2 { + return nil, errors.New("Wrong protocol part of URI") + } + db.proto = p[0] + options := strings.Split(p[1], ",") + db.raddr = options[0] + for _, o := range options[1:] { + kv := strings.SplitN(o, "=", 2) + var k, v string + if len(kv) == 2 { + k, v = kv[0], kv[1] + } else { + k, v = o, "true" + } + switch k { + case "laddr": + db.laddr = v + case "timeout": + to, err := time.ParseDuration(v) + if err != nil { + return nil, err + } + db.timeout = to + default: + return nil, errors.New("Unknown option: " + k) + } + } + // Remove protocol part + pd = pd[1:] + } + // Parse database part of URI + dup := strings.SplitN(pd[0], "/", 3) + if len(dup) != 3 { + return nil, errors.New("Wrong database part of URI") + } + db.dbName = dup[0] + db.user = dup[1] + db.passwd = dup[2] + + return db, nil } func (db *mymysql) Init(drivername, uri string) error { - db.mysql.base.init(drivername, uri) - pd := strings.SplitN(uri, "*", 2) - if len(pd) == 2 { - // Parse protocol part of URI - p := strings.SplitN(pd[0], ":", 2) - if len(p) != 2 { - return errors.New("Wrong protocol part of URI") - } - db.proto = p[0] - options := strings.Split(p[1], ",") - db.raddr = options[0] - for _, o := range options[1:] { - kv := strings.SplitN(o, "=", 2) - var k, v string - if len(kv) == 2 { - k, v = kv[0], kv[1] - } else { - k, v = o, "true" - } - switch k { - case "laddr": - db.laddr = v - case "timeout": - to, err := time.ParseDuration(v) - if err != nil { - return err - } - db.timeout = to - default: - return errors.New("Unknown option: " + k) - } - } - // Remove protocol part - pd = pd[1:] - } - // Parse database part of URI - dup := strings.SplitN(pd[0], "/", 3) - if len(dup) != 3 { - return errors.New("Wrong database part of URI") - } - db.dbname = dup[0] - db.user = dup[1] - db.passwd = dup[2] - - return nil + return db.mysql.base.init(&mymysqlParser{}, drivername, uri) } diff --git a/mysql.go b/mysql.go index 17e6603c..bde0186a 100644 --- a/mysql.go +++ b/mysql.go @@ -1,311 +1,323 @@ package xorm import ( - "crypto/tls" - "database/sql" - "errors" - "fmt" - "regexp" - "strconv" - "strings" - "time" + "crypto/tls" + "database/sql" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "time" ) -type base struct { - drivername string - dataSourceName string +type uri struct { + dbType string + proto string + host string + port string + dbName string + user string + passwd string + charset string + laddr string + raddr string + timeout time.Duration } -func (b *base) init(drivername, dataSourceName string) { - b.drivername, b.dataSourceName = drivername, dataSourceName +type parser interface { + parse(driverName, dataSourceName string) (*uri, error) +} + +type mysqlParser struct { +} + +func (p *mysqlParser) parse(driverName, dataSourceName string) (*uri, error) { + //cfg.params = make(map[string]string) + dsnPattern := regexp.MustCompile( + `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] + `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] + `\/(?P.*?)` + // /dbname + `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] + matches := dsnPattern.FindStringSubmatch(dataSourceName) + //tlsConfigRegister := make(map[string]*tls.Config) + names := dsnPattern.SubexpNames() + + uri := &uri{dbType: MYSQL} + + for i, match := range matches { + switch names[i] { + case "dbname": + uri.dbName = match + } + } + return uri, nil +} + +type base struct { + parser parser + driverName string + dataSourceName string + *uri +} + +func (b *base) init(parser parser, drivername, dataSourceName string) (err error) { + b.parser = parser + b.driverName, b.dataSourceName = drivername, dataSourceName + b.uri, err = b.parser.parse(b.driverName, b.dataSourceName) + return } type mysql struct { - base - user string - passwd string - net string - addr string - dbname string - params map[string]string - loc *time.Location - timeout time.Duration - tls *tls.Config - allowAllFiles bool - allowOldPasswords bool - clientFoundRows bool -} - -/*func readBool(input string) (value bool, valid bool) { - switch input { - case "1", "true", "TRUE", "True": - return true, true - case "0", "false", "FALSE", "False": - return false, true - } - - // Not a valid bool value - return -}*/ - -func (cfg *mysql) parseDSN(dsn string) (err error) { - //cfg.params = make(map[string]string) - dsnPattern := regexp.MustCompile( - `^(?:(?P.*?)(?::(?P.*))?@)?` + // [user[:password]@] - `(?:(?P[^\(]*)(?:\((?P[^\)]*)\))?)?` + // [net[(addr)]] - `\/(?P.*?)` + // /dbname - `(?:\?(?P[^\?]*))?$`) // [?param1=value1¶mN=valueN] - matches := dsnPattern.FindStringSubmatch(dsn) - //tlsConfigRegister := make(map[string]*tls.Config) - names := dsnPattern.SubexpNames() - - for i, match := range matches { - switch names[i] { - case "dbname": - cfg.dbname = match - } - } - return + base + net string + addr string + params map[string]string + loc *time.Location + timeout time.Duration + tls *tls.Config + allowAllFiles bool + allowOldPasswords bool + clientFoundRows bool } func (db *mysql) Init(drivername, uri string) error { - db.base.init(drivername, uri) - return db.parseDSN(uri) + return db.base.init(&mysqlParser{}, drivername, uri) } func (db *mysql) SqlType(c *Column) string { - var res string - switch t := c.SQLType.Name; t { - case Bool: - res = TinyInt - case Serial: - c.IsAutoIncrement = true - c.IsPrimaryKey = true - c.Nullable = false - res = Int - case BigSerial: - c.IsAutoIncrement = true - c.IsPrimaryKey = true - c.Nullable = false - res = BigInt - case Bytea: - res = Blob - case TimeStampz: - res = Char - c.Length = 64 - default: - res = t - } + var res string + switch t := c.SQLType.Name; t { + case Bool: + res = TinyInt + case Serial: + c.IsAutoIncrement = true + c.IsPrimaryKey = true + c.Nullable = false + res = Int + case BigSerial: + c.IsAutoIncrement = true + c.IsPrimaryKey = true + c.Nullable = false + res = BigInt + case Bytea: + res = Blob + case TimeStampz: + res = Char + c.Length = 64 + default: + res = t + } - var hasLen1 bool = (c.Length > 0) - var hasLen2 bool = (c.Length2 > 0) - if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" - } else if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" - } - return res + var hasLen1 bool = (c.Length > 0) + var hasLen2 bool = (c.Length2 > 0) + if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } else if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } + return res } func (db *mysql) SupportInsertMany() bool { - return true + return true } func (db *mysql) QuoteStr() string { - return "`" + return "`" } func (db *mysql) SupportEngine() bool { - return true + return true } func (db *mysql) AutoIncrStr() string { - return "AUTO_INCREMENT" + return "AUTO_INCREMENT" } func (db *mysql) SupportCharset() bool { - return true + return true } func (db *mysql) IndexOnTable() bool { - return true + return true } func (db *mysql) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{db.dbname, tableName, idxName} - sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" - sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" - return sql, args + args := []interface{}{db.dbName, tableName, idxName} + sql := "SELECT `INDEX_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS`" + sql += " WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `INDEX_NAME`=?" + return sql, args } func (db *mysql) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{db.dbname, tableName, colName} - sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" - return sql, args + args := []interface{}{db.dbName, tableName, colName} + sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" + return sql, args } func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{db.dbname, tableName} - sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" - return sql, args + args := []interface{}{db.dbName, tableName} + sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?" + return sql, args } func (db *mysql) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{db.dbname, tableName} - s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + - " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, record := range res { - col := new(Column) - col.Indexes = make(map[string]bool) - for name, content := range record { - switch name { - case "COLUMN_NAME": - col.Name = strings.Trim(string(content), "` ") - case "IS_NULLABLE": - if "YES" == string(content) { - col.Nullable = true - } - case "COLUMN_DEFAULT": - // add '' - col.Default = string(content) - case "COLUMN_TYPE": - cts := strings.Split(string(content), "(") - var len1, len2 int - if len(cts) == 2 { - idx := strings.Index(cts[1], ")") - lens := strings.Split(cts[1][0:idx], ",") - len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) - if err != nil { - return nil, nil, err - } - if len(lens) == 2 { - len2, err = strconv.Atoi(lens[1]) - if err != nil { - return nil, nil, err - } - } - } - colName := cts[0] - colType := strings.ToUpper(colName) - col.Length = len1 - col.Length2 = len2 - if _, ok := sqlTypes[colType]; ok { - col.SQLType = SQLType{colType, len1, len2} - } else { - return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) - } - case "COLUMN_KEY": - key := string(content) - if key == "PRI" { - col.IsPrimaryKey = true - } - if key == "UNI" { - //col.is - } - case "EXTRA": - extra := string(content) - if extra == "auto_increment" { - col.IsAutoIncrement = true - } - } - } - if col.SQLType.IsText() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } - return colSeq, cols, nil + args := []interface{}{db.dbName, tableName} + s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, record := range res { + col := new(Column) + col.Indexes = make(map[string]bool) + for name, content := range record { + switch name { + case "COLUMN_NAME": + col.Name = strings.Trim(string(content), "` ") + case "IS_NULLABLE": + if "YES" == string(content) { + col.Nullable = true + } + case "COLUMN_DEFAULT": + // add '' + col.Default = string(content) + case "COLUMN_TYPE": + cts := strings.Split(string(content), "(") + var len1, len2 int + if len(cts) == 2 { + idx := strings.Index(cts[1], ")") + lens := strings.Split(cts[1][0:idx], ",") + len1, err = strconv.Atoi(strings.TrimSpace(lens[0])) + if err != nil { + return nil, nil, err + } + if len(lens) == 2 { + len2, err = strconv.Atoi(lens[1]) + if err != nil { + return nil, nil, err + } + } + } + colName := cts[0] + colType := strings.ToUpper(colName) + col.Length = len1 + col.Length2 = len2 + if _, ok := sqlTypes[colType]; ok { + col.SQLType = SQLType{colType, len1, len2} + } else { + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", colType)) + } + case "COLUMN_KEY": + key := string(content) + if key == "PRI" { + col.IsPrimaryKey = true + } + if key == "UNI" { + //col.is + } + case "EXTRA": + extra := string(content) + if extra == "auto_increment" { + col.IsAutoIncrement = true + } + } + } + if col.SQLType.IsText() { + if col.Default != "" { + col.Default = "'" + col.Default + "'" + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + return colSeq, cols, nil } func (db *mysql) GetTables() ([]*Table, error) { - args := []interface{}{db.dbname} - s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{db.dbName} + s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "TABLE_NAME": - table.Name = strings.Trim(string(content), "` ") - case "ENGINE": - } - } - tables = append(tables, table) - } - return tables, nil + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "TABLE_NAME": + table.Name = strings.Trim(string(content), "` ") + case "ENGINE": + } + } + tables = append(tables, table) + } + return tables, nil } func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{db.dbname, tableName} - s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{db.dbName, tableName} + s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - indexes := make(map[string]*Index, 0) - for _, record := range res { - var indexType int - var indexName, colName string - for name, content := range record { - switch name { - case "NON_UNIQUE": - if "YES" == string(content) || string(content) == "1" { - indexType = IndexType - } else { - indexType = UniqueType - } - case "INDEX_NAME": - indexName = string(content) - case "COLUMN_NAME": - colName = strings.Trim(string(content), "` ") - } - } - if indexName == "PRIMARY" { - continue - } - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - indexName = indexName[5+len(tableName) : len(indexName)] - } + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName, colName string + for name, content := range record { + switch name { + case "NON_UNIQUE": + if "YES" == string(content) || string(content) == "1" { + indexType = IndexType + } else { + indexType = UniqueType + } + case "INDEX_NAME": + indexName = string(content) + case "COLUMN_NAME": + colName = strings.Trim(string(content), "` ") + } + } + if indexName == "PRIMARY" { + continue + } + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + indexName = indexName[5+len(tableName) : len(indexName)] + } - var index *Index - var ok bool - if index, ok = indexes[indexName]; !ok { - index = new(Index) - index.Type = indexType - index.Name = indexName - indexes[indexName] = index - } - index.AddColumn(colName) - } - return indexes, nil + var index *Index + var ok bool + if index, ok = indexes[indexName]; !ok { + index = new(Index) + index.Type = indexType + index.Name = indexName + indexes[indexName] = index + } + index.AddColumn(colName) + } + return indexes, nil } diff --git a/postgres.go b/postgres.go index 7b716c06..c316f9b5 100644 --- a/postgres.go +++ b/postgres.go @@ -1,300 +1,305 @@ package xorm import ( - "database/sql" - "errors" - "fmt" - "strconv" - "strings" + "database/sql" + "errors" + "fmt" + "strconv" + "strings" ) type postgres struct { - base - dbname string + base } type values map[string]string func (vs values) Set(k, v string) { - vs[k] = v + vs[k] = v } func (vs values) Get(k string) (v string) { - return vs[k] + return vs[k] } func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) + panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } func parseOpts(name string, o values) { - if len(name) == 0 { - return - } + if len(name) == 0 { + return + } - name = strings.TrimSpace(name) + name = strings.TrimSpace(name) - ps := strings.Split(name, " ") - for _, p := range ps { - kv := strings.Split(p, "=") - if len(kv) < 2 { - errorf("invalid option: %q", p) - } - o.Set(kv[0], kv[1]) - } + ps := strings.Split(name, " ") + for _, p := range ps { + kv := strings.Split(p, "=") + if len(kv) < 2 { + errorf("invalid option: %q", p) + } + o.Set(kv[0], kv[1]) + } +} + +type postgresParser struct { +} + +func (p *postgresParser) parse(driverName, dataSourceName string) (*uri, error) { + db := &uri{dbType: POSTGRES} + o := make(values) + parseOpts(dataSourceName, o) + + db.dbName = o.Get("dbname") + if db.dbName == "" { + return nil, errors.New("dbname is empty") + } + return db, nil } func (db *postgres) Init(drivername, uri string) error { - db.base.init(drivername, uri) - - o := make(values) - parseOpts(uri, o) - - db.dbname = o.Get("dbname") - if db.dbname == "" { - return errors.New("dbname is empty") - } - return nil + return db.base.init(&postgresParser{}, drivername, uri) } func (db *postgres) SqlType(c *Column) string { - var res string - switch t := c.SQLType.Name; t { - case TinyInt: - res = SmallInt - case MediumInt, Int, Integer: - return Integer - case Serial, BigSerial: - c.IsAutoIncrement = true - c.Nullable = false - res = t - case Binary, VarBinary: - return Bytea - case DateTime: - res = TimeStamp - case TimeStampz: - return "timestamp with time zone" - case Float: - res = Real - case TinyText, MediumText, LongText: - res = Text - case Blob, TinyBlob, MediumBlob, LongBlob: - return Bytea - case Double: - return "DOUBLE PRECISION" - default: - if c.IsAutoIncrement { - return Serial - } - res = t - } + var res string + switch t := c.SQLType.Name; t { + case TinyInt: + res = SmallInt + case MediumInt, Int, Integer: + return Integer + case Serial, BigSerial: + c.IsAutoIncrement = true + c.Nullable = false + res = t + case Binary, VarBinary: + return Bytea + case DateTime: + res = TimeStamp + case TimeStampz: + return "timestamp with time zone" + case Float: + res = Real + case TinyText, MediumText, LongText: + res = Text + case Blob, TinyBlob, MediumBlob, LongBlob: + return Bytea + case Double: + return "DOUBLE PRECISION" + default: + if c.IsAutoIncrement { + return Serial + } + res = t + } - var hasLen1 bool = (c.Length > 0) - var hasLen2 bool = (c.Length2 > 0) - if hasLen1 { - res += "(" + strconv.Itoa(c.Length) + ")" - } else if hasLen2 { - res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" - } - return res + var hasLen1 bool = (c.Length > 0) + var hasLen2 bool = (c.Length2 > 0) + if hasLen1 { + res += "(" + strconv.Itoa(c.Length) + ")" + } else if hasLen2 { + res += "(" + strconv.Itoa(c.Length) + "," + strconv.Itoa(c.Length2) + ")" + } + return res } func (db *postgres) SupportInsertMany() bool { - return true + return true } func (db *postgres) QuoteStr() string { - return "\"" + return "\"" } func (db *postgres) AutoIncrStr() string { - return "" + return "" } func (db *postgres) SupportEngine() bool { - return false + return false } func (db *postgres) SupportCharset() bool { - return false + return false } func (db *postgres) IndexOnTable() bool { - return false + return false } func (db *postgres) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{tableName, idxName} - return `SELECT indexname FROM pg_indexes ` + - `WHERE tablename = ? AND indexname = ?`, args + args := []interface{}{tableName, idxName} + return `SELECT indexname FROM pg_indexes ` + + `WHERE tablename = ? AND indexname = ?`, args } func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args + args := []interface{}{tableName} + return `SELECT tablename FROM pg_tables WHERE tablename = ?`, args } func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName, colName} - return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + - " AND column_name = ?", args + args := []interface{}{tableName, colName} + return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" + + " AND column_name = ?", args } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{tableName} - s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + - ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + args := []interface{}{tableName} + s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + + ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, record := range res { - col := new(Column) - col.Indexes = make(map[string]bool) - for name, content := range record { - switch name { - case "column_name": - col.Name = strings.Trim(string(content), `" `) - case "column_default": - if strings.HasPrefix(string(content), "nextval") { - col.IsPrimaryKey = true - } else { - col.Default = string(content) - } - case "is_nullable": - if string(content) == "YES" { - col.Nullable = true - } else { - col.Nullable = false - } - case "data_type": - ct := string(content) - switch ct { - case "character varying", "character": - col.SQLType = SQLType{Varchar, 0, 0} - case "timestamp without time zone": - col.SQLType = SQLType{DateTime, 0, 0} - case "timestamp with time zone": - col.SQLType = SQLType{TimeStampz, 0, 0} - case "double precision": - col.SQLType = SQLType{Double, 0, 0} - case "boolean": - col.SQLType = SQLType{Bool, 0, 0} - case "time without time zone": - col.SQLType = SQLType{Time, 0, 0} - default: - col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} - } - if _, ok := sqlTypes[col.SQLType.Name]; !ok { - return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) - } - case "character_maximum_length": - i, err := strconv.Atoi(string(content)) - if err != nil { - return nil, nil, errors.New("retrieve length error") - } - col.Length = i - case "numeric_precision": - case "numeric_precision_radix": - } - } - if col.SQLType.IsText() { - if col.Default != "" { - col.Default = "'" + col.Default + "'" - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, record := range res { + col := new(Column) + col.Indexes = make(map[string]bool) + for name, content := range record { + switch name { + case "column_name": + col.Name = strings.Trim(string(content), `" `) + case "column_default": + if strings.HasPrefix(string(content), "nextval") { + col.IsPrimaryKey = true + } else { + col.Default = string(content) + } + case "is_nullable": + if string(content) == "YES" { + col.Nullable = true + } else { + col.Nullable = false + } + case "data_type": + ct := string(content) + switch ct { + case "character varying", "character": + col.SQLType = SQLType{Varchar, 0, 0} + case "timestamp without time zone": + col.SQLType = SQLType{DateTime, 0, 0} + case "timestamp with time zone": + col.SQLType = SQLType{TimeStampz, 0, 0} + case "double precision": + col.SQLType = SQLType{Double, 0, 0} + case "boolean": + col.SQLType = SQLType{Bool, 0, 0} + case "time without time zone": + col.SQLType = SQLType{Time, 0, 0} + default: + col.SQLType = SQLType{strings.ToUpper(ct), 0, 0} + } + if _, ok := sqlTypes[col.SQLType.Name]; !ok { + return nil, nil, errors.New(fmt.Sprintf("unkonw colType %v", ct)) + } + case "character_maximum_length": + i, err := strconv.Atoi(string(content)) + if err != nil { + return nil, nil, errors.New("retrieve length error") + } + col.Length = i + case "numeric_precision": + case "numeric_precision_radix": + } + } + if col.SQLType.IsText() { + if col.Default != "" { + col.Default = "'" + col.Default + "'" + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } - return colSeq, cols, nil + return colSeq, cols, nil } func (db *postgres) GetTables() ([]*Table, error) { - args := []interface{}{} - s := "SELECT tablename FROM pg_tables where schemaname = 'public'" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{} + s := "SELECT tablename FROM pg_tables where schemaname = 'public'" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "tablename": - table.Name = string(content) - } - } - tables = append(tables, table) - } - return tables, nil + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "tablename": + table.Name = string(content) + } + } + tables = append(tables, table) + } + return tables, nil } func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{tableName} - s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" + args := []interface{}{tableName} + s := "SELECT tablename, indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - indexes := make(map[string]*Index, 0) - for _, record := range res { - var indexType int - var indexName string - var colNames []string + indexes := make(map[string]*Index, 0) + for _, record := range res { + var indexType int + var indexName string + var colNames []string - for name, content := range record { - switch name { - case "indexname": - indexName = strings.Trim(string(content), `" `) - case "indexdef": - c := string(content) - if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { - indexType = UniqueType - } else { - indexType = IndexType - } - cs := strings.Split(c, "(") - colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") - } - } - if strings.HasSuffix(indexName, "_pkey") { - continue - } - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - newIdxName := indexName[5+len(tableName) : len(indexName)] - if newIdxName != "" { - indexName = newIdxName - } - } + for name, content := range record { + switch name { + case "indexname": + indexName = strings.Trim(string(content), `" `) + case "indexdef": + c := string(content) + if strings.HasPrefix(c, "CREATE UNIQUE INDEX") { + indexType = UniqueType + } else { + indexType = IndexType + } + cs := strings.Split(c, "(") + colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") + } + } + if strings.HasSuffix(indexName, "_pkey") { + continue + } + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + newIdxName := indexName[5+len(tableName) : len(indexName)] + if newIdxName != "" { + indexName = newIdxName + } + } - index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} - for _, colName := range colNames { - index.Cols = append(index.Cols, strings.Trim(colName, `" `)) - } - indexes[index.Name] = index - } - return indexes, nil + index := &Index{Name: indexName, Type: indexType, Cols: make([]string, 0)} + for _, colName := range colNames { + index.Cols = append(index.Cols, strings.Trim(colName, `" `)) + } + indexes[index.Name] = index + } + return indexes, nil } diff --git a/sqlite3.go b/sqlite3.go index eb42e999..84a9d1b0 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -1,223 +1,229 @@ package xorm import ( - "database/sql" - "strings" + "database/sql" + "strings" ) type sqlite3 struct { - base + base +} + +type sqlite3Parser struct { +} + +func (p *sqlite3Parser) parse(driverName, dataSourceName string) (*uri, error) { + return &uri{dbType: SQLITE, dbName: dataSourceName}, nil } func (db *sqlite3) Init(drivername, dataSourceName string) error { - db.base.init(drivername, dataSourceName) - return nil + return db.base.init(&sqlite3Parser{}, drivername, dataSourceName) } func (db *sqlite3) SqlType(c *Column) string { - switch t := c.SQLType.Name; t { - case Date, DateTime, TimeStamp, Time: - return Numeric - case TimeStampz: - return Text - case Char, Varchar, TinyText, Text, MediumText, LongText: - return Text - case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: - return Integer - case Float, Double, Real: - return Real - case Decimal, Numeric: - return Numeric - case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: - return Blob - case Serial, BigSerial: - c.IsPrimaryKey = true - c.IsAutoIncrement = true - c.Nullable = false - return Integer - default: - return t - } + switch t := c.SQLType.Name; t { + case Date, DateTime, TimeStamp, Time: + return Numeric + case TimeStampz: + return Text + case Char, Varchar, TinyText, Text, MediumText, LongText: + return Text + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, BigInt, Bool: + return Integer + case Float, Double, Real: + return Real + case Decimal, Numeric: + return Numeric + case TinyBlob, Blob, MediumBlob, LongBlob, Bytea, Binary, VarBinary: + return Blob + case Serial, BigSerial: + c.IsPrimaryKey = true + c.IsAutoIncrement = true + c.Nullable = false + return Integer + default: + return t + } } func (db *sqlite3) SupportInsertMany() bool { - return true + return true } func (db *sqlite3) QuoteStr() string { - return "`" + return "`" } func (db *sqlite3) AutoIncrStr() string { - return "AUTOINCREMENT" + return "AUTOINCREMENT" } func (db *sqlite3) SupportEngine() bool { - return false + return false } func (db *sqlite3) SupportCharset() bool { - return false + return false } func (db *sqlite3) IndexOnTable() bool { - return false + return false } func (db *sqlite3) IndexCheckSql(tableName, idxName string) (string, []interface{}) { - args := []interface{}{idxName} - return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args + args := []interface{}{idxName} + return "SELECT name FROM sqlite_master WHERE type='index' and name = ?", args } func (db *sqlite3) TableCheckSql(tableName string) (string, []interface{}) { - args := []interface{}{tableName} - return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args + args := []interface{}{tableName} + return "SELECT name FROM sqlite_master WHERE type='table' and name = ?", args } func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { - args := []interface{}{tableName} - sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" - return sql, args + args := []interface{}{tableName} + sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" + return sql, args } func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*Column, error) { - args := []interface{}{tableName} - s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, nil, err - } + args := []interface{}{tableName} + s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, nil, err + } - var sql string - for _, record := range res { - for name, content := range record { - if name == "sql" { - sql = string(content) - } - } - } + var sql string + for _, record := range res { + for name, content := range record { + if name == "sql" { + sql = string(content) + } + } + } - nStart := strings.Index(sql, "(") - nEnd := strings.Index(sql, ")") - colCreates := strings.Split(sql[nStart+1:nEnd], ",") - cols := make(map[string]*Column) - colSeq := make([]string, 0) - for _, colStr := range colCreates { - fields := strings.Fields(strings.TrimSpace(colStr)) - col := new(Column) - col.Indexes = make(map[string]bool) - col.Nullable = true - for idx, field := range fields { - if idx == 0 { - col.Name = strings.Trim(field, "`[] ") - continue - } else if idx == 1 { - col.SQLType = SQLType{field, 0, 0} - } - switch field { - case "PRIMARY": - col.IsPrimaryKey = true - case "AUTOINCREMENT": - col.IsAutoIncrement = true - case "NULL": - if fields[idx-1] == "NOT" { - col.Nullable = false - } else { - col.Nullable = true - } - } - } - cols[col.Name] = col - colSeq = append(colSeq, col.Name) - } - return colSeq, cols, nil + nStart := strings.Index(sql, "(") + nEnd := strings.Index(sql, ")") + colCreates := strings.Split(sql[nStart+1:nEnd], ",") + cols := make(map[string]*Column) + colSeq := make([]string, 0) + for _, colStr := range colCreates { + fields := strings.Fields(strings.TrimSpace(colStr)) + col := new(Column) + col.Indexes = make(map[string]bool) + col.Nullable = true + for idx, field := range fields { + if idx == 0 { + col.Name = strings.Trim(field, "`[] ") + continue + } else if idx == 1 { + col.SQLType = SQLType{field, 0, 0} + } + switch field { + case "PRIMARY": + col.IsPrimaryKey = true + case "AUTOINCREMENT": + col.IsAutoIncrement = true + case "NULL": + if fields[idx-1] == "NOT" { + col.Nullable = false + } else { + col.Nullable = true + } + } + } + cols[col.Name] = col + colSeq = append(colSeq, col.Name) + } + return colSeq, cols, nil } func (db *sqlite3) GetTables() ([]*Table, error) { - args := []interface{}{} - s := "SELECT name FROM sqlite_master WHERE type='table'" + args := []interface{}{} + s := "SELECT name FROM sqlite_master WHERE type='table'" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - tables := make([]*Table, 0) - for _, record := range res { - table := new(Table) - for name, content := range record { - switch name { - case "name": - table.Name = string(content) - } - } - if table.Name == "sqlite_sequence" { - continue - } - tables = append(tables, table) - } - return tables, nil + tables := make([]*Table, 0) + for _, record := range res { + table := new(Table) + for name, content := range record { + switch name { + case "name": + table.Name = string(content) + } + } + if table.Name == "sqlite_sequence" { + continue + } + tables = append(tables, table) + } + return tables, nil } func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) { - args := []interface{}{tableName} - s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - cnn, err := sql.Open(db.drivername, db.dataSourceName) - if err != nil { - return nil, err - } - defer cnn.Close() - res, err := query(cnn, s, args...) - if err != nil { - return nil, err - } + args := []interface{}{tableName} + s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" + cnn, err := sql.Open(db.driverName, db.dataSourceName) + if err != nil { + return nil, err + } + defer cnn.Close() + res, err := query(cnn, s, args...) + if err != nil { + return nil, err + } - indexes := make(map[string]*Index, 0) - for _, record := range res { - var sql string - index := new(Index) - for name, content := range record { - if name == "sql" { - sql = string(content) - } - } + indexes := make(map[string]*Index, 0) + for _, record := range res { + var sql string + index := new(Index) + for name, content := range record { + if name == "sql" { + sql = string(content) + } + } - nNStart := strings.Index(sql, "INDEX") - nNEnd := strings.Index(sql, "ON") - indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") - //fmt.Println(indexName) - if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { - index.Name = indexName[5+len(tableName) : len(indexName)] - } else { - index.Name = indexName - } + nNStart := strings.Index(sql, "INDEX") + nNEnd := strings.Index(sql, "ON") + indexName := strings.Trim(sql[nNStart+6:nNEnd], "` []") + //fmt.Println(indexName) + if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { + index.Name = indexName[5+len(tableName) : len(indexName)] + } else { + index.Name = indexName + } - if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { - index.Type = UniqueType - } else { - index.Type = IndexType - } + if strings.HasPrefix(sql, "CREATE UNIQUE INDEX") { + index.Type = UniqueType + } else { + index.Type = IndexType + } - nStart := strings.Index(sql, "(") - nEnd := strings.Index(sql, ")") - colIndexes := strings.Split(sql[nStart+1:nEnd], ",") + nStart := strings.Index(sql, "(") + nEnd := strings.Index(sql, ")") + colIndexes := strings.Split(sql[nStart+1:nEnd], ",") - index.Cols = make([]string, 0) - for _, col := range colIndexes { - index.Cols = append(index.Cols, strings.Trim(col, "` []")) - } - indexes[index.Name] = index - } + index.Cols = make([]string, 0) + for _, col := range colIndexes { + index.Cols = append(index.Cols, strings.Trim(col, "` []")) + } + indexes[index.Name] = index + } - return indexes, nil + return indexes, nil }