added xorm reverse tool

This commit is contained in:
Lunny Xiao 2013-10-12 23:16:51 +08:00
parent 58f95d6f27
commit 2caed88b82
20 changed files with 1052 additions and 116 deletions

View File

@ -3,6 +3,7 @@ package xorm
import (
"errors"
"fmt"
"strings"
"testing"
"time"
)
@ -98,7 +99,7 @@ func insert(engine *Engine, t *testing.T) {
}
}
func query(engine *Engine, t *testing.T) {
func testQuery(engine *Engine, t *testing.T) {
sql := "select * from userinfo"
results, err := engine.Query(sql)
if err != nil {
@ -163,6 +164,19 @@ func insertMulti(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
users2 := []*Userinfo{
&Userinfo{Username: "1xlw", Departname: "dev", Alias: "lunny2", Created: time.Now()},
&Userinfo{Username: "1xlw2", Departname: "dev", Alias: "lunny3", Created: time.Now()},
&Userinfo{Username: "1xlw11", Departname: "dev", Alias: "lunny2", Created: time.Now()},
&Userinfo{Username: "1xlw22", Departname: "dev", Alias: "lunny3", Created: time.Now()},
}
_, err = engine.Insert(&users2)
if err != nil {
t.Error(err)
panic(err)
}
}
func insertTwoTable(engine *Engine, t *testing.T) {
@ -1018,6 +1032,18 @@ func testIndexAndUnique(engine *Engine, t *testing.T) {
t.Error(err)
//panic(err)
}
err = engine.CreateIndexes(&IndexOrUnique{})
if err != nil {
t.Error(err)
//panic(err)
}
err = engine.CreateUniques(&IndexOrUnique{})
if err != nil {
t.Error(err)
//panic(err)
}
}
type IntId struct {
@ -1042,6 +1068,12 @@ func testIntId(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
_, err = engine.Insert(&IntId{Name: "test"})
if err != nil {
t.Error(err)
panic(err)
}
}
func testInt32Id(engine *Engine, t *testing.T) {
@ -1056,6 +1088,31 @@ func testInt32Id(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
_, err = engine.Insert(&Int32Id{Name: "test"})
if err != nil {
t.Error(err)
panic(err)
}
}
func testMetaInfo(engine *Engine, t *testing.T) {
tables, err := engine.DBMetas()
if err != nil {
t.Error(err)
panic(err)
}
for _, table := range tables {
fmt.Println(table.Name)
for _, col := range table.Columns {
fmt.Println(col.String(engine.dialect))
}
for _, index := range table.Indexes {
fmt.Println(index.Name, index.Type, strings.Join(index.Cols, ","))
}
}
}
func testAll(engine *Engine, t *testing.T) {
@ -1066,7 +1123,7 @@ func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- insert --------------")
insert(engine, t)
fmt.Println("-------------- query --------------")
query(engine, t)
testQuery(engine, t)
fmt.Println("-------------- exec --------------")
exec(engine, t)
fmt.Println("-------------- insertAutoIncr --------------")
@ -1132,6 +1189,12 @@ func testAll2(engine *Engine, t *testing.T) {
testCreatedAndUpdated(engine, t)
fmt.Println("-------------- testIndexAndUnique --------------")
testIndexAndUnique(engine, t)
fmt.Println("-------------- testIntId --------------")
//testIntId(engine, t)
fmt.Println("-------------- testInt32Id --------------")
//testInt32Id(engine, t)
fmt.Println("-------------- testMetaInfo --------------")
testMetaInfo(engine, t)
fmt.Println("-------------- transaction --------------")
transaction(engine, t)
}

129
engine.go
View File

@ -20,7 +20,7 @@ const (
// a dialect is a driver's wrapper
type dialect interface {
Init(uri string) error
Init(DriverName, DataSourceName string) error
SqlType(t *Column) string
SupportInsertMany() bool
QuoteStr() string
@ -31,6 +31,10 @@ type dialect interface {
IndexCheckSql(tableName, idxName string) (string, []interface{})
TableCheckSql(tableName string) (string, []interface{})
ColumnCheckSql(tableName, colName string) (string, []interface{})
GetColumns(tableName string) (map[string]*Column, error)
GetTables() ([]*Table, error)
GetIndexes(tableName string) (map[string]*Index, error)
}
type Engine struct {
@ -38,7 +42,7 @@ type Engine struct {
TagIdentifier string
DriverName string
DataSourceName string
Dialect dialect
dialect dialect
Tables map[reflect.Type]*Table
mutex *sync.Mutex
ShowSQL bool
@ -57,28 +61,28 @@ type Engine struct {
// When the return is ture, then engine.Insert(&users) will
// generate batch sql and exeute.
func (engine *Engine) SupportInsertMany() bool {
return engine.Dialect.SupportInsertMany()
return engine.dialect.SupportInsertMany()
}
// Engine's database use which charactor as quote.
// mysql, sqlite use ` and postgres use "
func (engine *Engine) QuoteStr() string {
return engine.Dialect.QuoteStr()
return engine.dialect.QuoteStr()
}
// Use QuoteStr quote the string sql
func (engine *Engine) Quote(sql string) string {
return engine.Dialect.QuoteStr() + sql + engine.Dialect.QuoteStr()
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
}
// A simple wrapper to dialect's SqlType method
func (engine *Engine) SqlType(c *Column) string {
return engine.Dialect.SqlType(c)
return engine.dialect.SqlType(c)
}
// Database's autoincrement statement
func (engine *Engine) AutoIncrStr() string {
return engine.Dialect.AutoIncrStr()
return engine.dialect.AutoIncrStr()
}
// Set engine's pool, the pool default is Go's standard library's connection pool.
@ -178,6 +182,28 @@ func (engine *Engine) NoAutoTime() *Session {
return session.NoAutoTime()
}
func (engine *Engine) DBMetas() ([]*Table, error) {
tables, err := engine.dialect.GetTables()
if err != nil {
return nil, err
}
for _, table := range tables {
cols, err := engine.dialect.GetColumns(table.Name)
if err != nil {
return nil, err
}
table.Columns = cols
indexes, err := engine.dialect.GetIndexes(table.Name)
if err != nil {
return nil, err
}
table.Indexes = indexes
}
return tables, nil
}
func (engine *Engine) Cascade(trueOrFalse ...bool) *Session {
session := engine.NewSession()
session.IsAutoClose = true
@ -316,7 +342,7 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
if ormTagStr != "" {
col = &Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false,
IsAutoIncrement: false, MapType: TWOSIDES}
IsAutoIncrement: false, MapType: TWOSIDES, Indexes: make(map[string]bool)}
tags := strings.Split(ormTagStr, " ")
if len(tags) > 0 {
@ -335,6 +361,8 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
table.PrimaryKey = parentTable.PrimaryKey
continue
}
var indexType int
var indexName string
for j, key := range tags {
k := strings.ToUpper(key)
switch {
@ -358,37 +386,15 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
/*case strings.HasPrefix(k, "--"):
col.Comment = k[2:len(k)]*/
case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
indexName := k[len("INDEX")+1 : len(k)-1]
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col)
col.Index = index
} else {
index := NewIndex(indexName, false)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
}
indexType = IndexType
indexName = k[len("INDEX")+1 : len(k)-1]
case k == "INDEX":
index := NewIndex(col.Name, false)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
indexType = IndexType
case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
indexName := k[len("UNIQUE")+1 : len(k)-1]
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col)
col.Index = index
} else {
index := NewIndex(indexName, true)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
}
indexName = k[len("UNIQUE")+1 : len(k)-1]
indexType = UniqueType
case k == "UNIQUE":
index := NewIndex(col.Name, true)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
indexType = UniqueType
case k == "NOTNULL":
col.Nullable = false
case k == "NOT":
@ -432,12 +438,39 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
if col.Name == "" {
col.Name = engine.Mapper.Obj2Table(t.Field(i).Name)
}
if indexType == IndexType {
if indexName == "" {
indexName = col.Name
}
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col.Name)
col.Indexes[index.Name] = true
} else {
index := NewIndex(indexName, IndexType)
index.AddColumn(col.Name)
table.AddIndex(index)
col.Indexes[index.Name] = true
}
} else if indexType == UniqueType {
if indexName == "" {
indexName = col.Name
}
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col.Name)
col.Indexes[index.Name] = true
} else {
index := NewIndex(indexName, UniqueType)
index.AddColumn(col.Name)
table.AddIndex(index)
col.Indexes[index.Name] = true
}
}
}
} else {
sqlType := Type2SQLType(fieldType)
col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", nil, false, false,
TWOSIDES, false, false, ""}
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", make(map[string]bool), false, false,
TWOSIDES, false, false, false}
}
if col.IsAutoIncrement {
col.Nullable = false
@ -498,6 +531,20 @@ func (engine *Engine) IsTableExist(bean interface{}) (bool, error) {
return has, err
}
// create indexes
func (engine *Engine) CreateIndexes(bean interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.CreateIndexes(bean)
}
// create uniques
func (engine *Engine) CreateUniques(bean interface{}) error {
session := engine.NewSession()
defer session.Close()
return session.CreateUniques(bean)
}
// If enabled cache, clear the cache bean
func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error {
t := rType(bean)
@ -585,7 +632,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
if index.IsUnique {
if index.Type == UniqueType {
isExist, err := session.isIndexExist(table.Name, name, true)
if err != nil {
return err
@ -599,7 +646,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
}
} else {
} else if index.Type == IndexType {
isExist, err := session.isIndexExist(table.Name, name, false)
if err != nil {
return err
@ -613,6 +660,8 @@ func (engine *Engine) Sync(beans ...interface{}) error {
return err
}
}
} else {
return errors.New("unknow index type")
}
}
}

View File

@ -5,10 +5,11 @@ import (
)
var (
ErrParamsType error = errors.New("params type error")
ErrTableNotFound error = errors.New("not found table")
ErrUnSupportedType error = errors.New("unsupported type error")
ErrNotExist error = errors.New("not exist error")
ErrCacheFailed error = errors.New("cache failed")
ErrNeedDeletedCond error = errors.New("delete need at least one condition")
ErrParamsType error = errors.New("Params type error")
ErrTableNotFound error = errors.New("Not found table")
ErrUnSupportedType error = errors.New("Unsupported type error")
ErrNotExist error = errors.New("Not exist error")
ErrCacheFailed error = errors.New("Cache failed")
ErrNeedDeletedCond error = errors.New("Delete need at least one condition")
ErrNotImplemented error = errors.New("Not implemented.")
)

View File

@ -12,6 +12,11 @@ type SyncUser struct {
Id int64
Name string `xorm:"unique"`
Age int `xorm:"index"`
Title string
Address string
Genre string
Area string
Date int
}
type SyncLoginInfo struct {
@ -61,5 +66,19 @@ func main() {
if err != nil {
fmt.Println(err)
}
user := &SyncUser{
Name: "testsdf",
Age: 15,
Title: "newsfds",
Address: "fasfdsafdsaf",
Genre: "fsafd",
Area: "fafdsafd",
Date: 1000,
}
_, err = Orm.Insert(user)
if err != nil {
fmt.Println(err)
}
}
}

View File

@ -75,7 +75,9 @@ func titleCasedName(name string) string {
switch {
case upNextChar:
upNextChar = false
if 'a' <= chr && chr <= 'z' {
chr -= ('a' - 'A')
}
case chr == '_':
upNextChar = true
continue

View File

@ -17,7 +17,8 @@ type mymysql struct {
passwd string
}
func (db *mymysql) Init(uri string) error {
func (db *mymysql) Init(drivername, uri string) error {
db.mysql.base.init(drivername, uri)
pd := strings.SplitN(uri, "*", 2)
if len(pd) == 2 {
// Parse protocol part of URI

165
mysql.go
View File

@ -2,14 +2,26 @@ package xorm
import (
"crypto/tls"
//"fmt"
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
//"strings"
"strings"
"time"
)
type base struct {
drivername string
dataSourceName string
}
func (b *base) init(drivername, dataSourceName string) {
b.drivername, b.dataSourceName = drivername, dataSourceName
}
type mysql struct {
base
user string
passwd string
net string
@ -56,7 +68,8 @@ func (cfg *mysql) parseDSN(dsn string) (err error) {
return
}
func (db *mysql) Init(uri string) error {
func (db *mysql) Init(drivername, uri string) error {
db.base.init(drivername, uri)
return db.parseDSN(uri)
}
@ -133,3 +146,149 @@ func (db *mysql) TableCheckSql(tableName string) (string, []interface{}) {
sql := "SELECT `TABLE_NAME` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=? and `TABLE_NAME`=?"
return sql, args
}
func (db *mysql) GetColumns(tableName string) (map[string]*Column, error) {
args := []interface{}{db.dbname, tableName}
s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
cols := make(map[string]*Column)
for _, record := range res {
col := new(Column)
for name, content := range record {
switch name {
case "COLUMN_NAME":
col.Name = string(content)
case "IS_NULLABLE":
if "YES" == string(content) {
col.Nullable = true
}
case "COLUMN_DEFAULT":
// add ''
col.Default = string(content)
case "COLUMN_TYPE":
cts := strings.Split(string(content), "(")
var len1, len2 int
if len(cts) == 2 {
lens := strings.Split(cts[1][0:len(cts[1])-1], ",")
len1, err = strconv.Atoi(strings.TrimSpace(lens[0]))
if err != nil {
return nil, err
}
if len(lens) == 2 {
len2, err = strconv.Atoi(lens[1])
if err != nil {
return nil, err
}
}
}
colName := cts[0]
colType := strings.ToUpper(colName)
col.Length = len1
col.Length2 = len2
if _, ok := sqlTypes[colType]; ok {
col.SQLType = SQLType{colType, len1, len2}
} else {
return nil, errors.New(fmt.Sprintf("unkonw colType %v", colType))
}
case "COLUMN_KEY":
key := string(content)
if key == "PRI" {
col.IsPrimaryKey = true
}
if key == "UNI" {
//col.is
}
case "EXTRA":
extra := string(content)
if extra == "auto_increment" {
col.IsAutoIncrement = true
}
}
}
cols[col.Name] = col
}
return cols, nil
}
func (db *mysql) GetTables() ([]*Table, error) {
args := []interface{}{db.dbname}
s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
tables := make([]*Table, 0)
for _, record := range res {
table := new(Table)
for name, content := range record {
switch name {
case "TABLE_NAME":
table.Name = string(content)
case "ENGINE":
}
}
tables = append(tables, table)
}
return tables, nil
}
func (db *mysql) GetIndexes(tableName string) (map[string]*Index, error) {
args := []interface{}{db.dbname, tableName}
s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
indexes := make(map[string]*Index, 0)
for _, record := range res {
var indexType int
var indexName, colName string
for name, content := range record {
switch name {
case "NON_UNIQUE":
if "YES" == string(content) {
indexType = IndexType
} else {
indexType = UniqueType
}
case "INDEX_NAME":
indexName = string(content)
case "COLUMN_NAME":
colName = string(content)
}
}
if indexName == "PRIMARY" {
continue
}
indexName = indexName[5+len(tableName) : len(indexName)]
var index *Index
var ok bool
if index, ok = indexes[indexName]; !ok {
index = new(Index)
index.Type = indexType
index.Name = indexName
indexes[indexName] = index
}
index.AddColumn(colName)
}
return indexes, nil
}

View File

@ -8,6 +8,7 @@ import (
)
type postgres struct {
base
dbname string
}
@ -44,7 +45,9 @@ func parseOpts(name string, o values) {
}
}
func (db *postgres) Init(uri string) error {
func (db *postgres) Init(drivername, uri string) error {
db.base.init(drivername, uri)
o := make(values)
parseOpts(uri, o)
@ -135,3 +138,43 @@ func (db *postgres) ColumnCheckSql(tableName, colName string) (string, []interfa
return "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?" +
" AND column_name = ?", args
}
func (db *postgres) GetColumns(tableName string) (map[string]*Column, error) {
args := []interface{}{tableName}
s := "SELECT COLUMN_NAME, column_default, is_nullable, data_type, character_maximum_length" +
" FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = ?"
cnn, err := sql.Open(db.drivername, db.dataSourceName)
if err != nil {
return nil, err
}
res, err := query(cnn, s, args...)
if err != nil {
return nil, err
}
cols := make(map[string]*Column)
for _, record := range res {
col := new(Column)
for name, content := range record {
switch name {
case "COLUMN_NAME":
col.Name = string(content)
case "column_default":
if strings.HasPrefix(string(content), "") {
col.IsPrimaryKey
}
}
}
}
return nil, ErrNotImplemented
}
func (db *postgres) GetTables() ([]*Table, error) {
return nil, ErrNotImplemented
}
func (db *postgres) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, ErrNotImplemented
}

View File

@ -11,6 +11,8 @@ import (
"time"
)
// Struct Session keep a pointer to sql.DB and provides all execution of all
// kind of database operations.
type Session struct {
Db *sql.DB
Engine *Engine
@ -22,6 +24,7 @@ type Session struct {
IsAutoClose bool
}
// Method Init reset the session as the init status.
func (session *Session) Init() {
session.Statement = Statement{Engine: session.Engine}
session.Statement.Init()
@ -30,6 +33,7 @@ func (session *Session) Init() {
session.IsAutoClose = false
}
// Method Close release the connection from pool
func (session *Session) Close() {
defer func() {
if session.Db != nil {
@ -41,56 +45,64 @@ func (session *Session) Close() {
}()
}
// Method Sql provides raw sql input parameter. When you have a complex SQL statement
// and cannot use Where, Id, In and etc. Methods to describe, you can use Sql.
func (session *Session) Sql(querystring string, args ...interface{}) *Session {
session.Statement.Sql(querystring, args...)
return session
}
// Method Where provides custom query condition.
func (session *Session) Where(querystring string, args ...interface{}) *Session {
session.Statement.Where(querystring, args...)
return session
}
// Method Id provides converting id as a query condition
func (session *Session) Id(id int64) *Session {
session.Statement.Id(id)
return session
}
// Method Table can input a string or pointer to struct for special a table to operate.
func (session *Session) Table(tableNameOrBean interface{}) *Session {
session.Statement.Table(tableNameOrBean)
return session
}
// Method In provides a query string like "id in (1, 2, 3)"
func (session *Session) In(column string, args ...interface{}) *Session {
session.Statement.In(column, args...)
return session
}
// Method Cols provides some columns to special
func (session *Session) Cols(columns ...string) *Session {
session.Statement.Cols(columns...)
return session
}
// Method NoAutoTime means do not automatically give created field and updated field
// the current time on the current session temporarily
func (session *Session) NoAutoTime() *Session {
session.Statement.UseAutoTime = false
return session
}
/*func (session *Session) Trans(t string) *Session {
session.TransType = t
return session
}*/
// Method Limit provide limit and offset query condition
func (session *Session) Limit(limit int, start ...int) *Session {
session.Statement.Limit(limit, start...)
return session
}
// Method OrderBy provide order by query condition, the input parameter is the content
// after order by on a sql statement.
func (session *Session) OrderBy(order string) *Session {
session.Statement.OrderBy(order)
return session
}
// Method Desc provide desc order by query condition, the input parameters are columns.
func (session *Session) Desc(colNames ...string) *Session {
if session.Statement.OrderStr != "" {
session.Statement.OrderStr += ", "
@ -101,6 +113,7 @@ func (session *Session) Desc(colNames ...string) *Session {
return session
}
// Method Asc provide asc order by query condition, the input parameters are columns.
func (session *Session) Asc(colNames ...string) *Session {
if session.Statement.OrderStr != "" {
session.Statement.OrderStr += ", "
@ -111,16 +124,19 @@ func (session *Session) Asc(colNames ...string) *Session {
return session
}
// Method StoreEngine is only avialble mysql dialect currently
func (session *Session) StoreEngine(storeEngine string) *Session {
session.Statement.StoreEngine = storeEngine
return session
}
// Method StoreEngine is only avialble charset dialect currently
func (session *Session) Charset(charset string) *Session {
session.Statement.Charset = charset
return session
}
// Method Cascade
func (session *Session) Cascade(trueOrFalse ...bool) *Session {
if len(trueOrFalse) >= 1 {
session.Statement.UseCascade = trueOrFalse[0]
@ -128,6 +144,8 @@ func (session *Session) Cascade(trueOrFalse ...bool) *Session {
return session
}
// Method NoCache ask this session do not retrieve data from cache system and
// get data from database directly.
func (session *Session) NoCache() *Session {
session.Statement.UseCache = false
return session
@ -836,7 +854,7 @@ func (session *Session) isColumnExist(tableName, colName string) (bool, error) {
if session.IsAutoClose {
defer session.Close()
}
sql, args := session.Engine.Dialect.ColumnCheckSql(tableName, colName)
sql, args := session.Engine.dialect.ColumnCheckSql(tableName, colName)
results, err := session.query(sql, args...)
return len(results) > 0, err
}
@ -850,7 +868,7 @@ func (session *Session) isTableExist(tableName string) (bool, error) {
if session.IsAutoClose {
defer session.Close()
}
sql, args := session.Engine.Dialect.TableCheckSql(tableName)
sql, args := session.Engine.dialect.TableCheckSql(tableName)
results, err := session.query(sql, args...)
return len(results) > 0, err
}
@ -870,7 +888,7 @@ func (session *Session) isIndexExist(tableName, idxName string, unique bool) (bo
} else {
idx = indexName(tableName, idxName)
}
sql, args := session.Engine.Dialect.IndexCheckSql(tableName, idx)
sql, args := session.Engine.dialect.IndexCheckSql(tableName, idx)
results, err := session.query(sql, args...)
return len(results) > 0, err
}
@ -901,7 +919,7 @@ func (session *Session) addIndex(tableName, idxName string) error {
defer session.Close()
}
//fmt.Println(idxName)
cols := session.Statement.RefTable.Indexes[idxName].GenColsStr()
cols := session.Statement.RefTable.Indexes[idxName].Cols
sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols)
_, err = session.exec(sql, args...)
return err
@ -917,7 +935,7 @@ func (session *Session) addUnique(tableName, uqeName string) error {
defer session.Close()
}
//fmt.Println(uqeName, session.Statement.RefTable.Uniques)
cols := session.Statement.RefTable.Indexes[uqeName].GenColsStr()
cols := session.Statement.RefTable.Indexes[uqeName].Cols
sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols)
_, err = session.exec(sql, args...)
return err
@ -945,6 +963,79 @@ func (session *Session) DropAll() error {
return nil
}
func query(db *sql.DB, sql string, params ...interface{}) (resultsSlice []map[string][]byte, err error) {
s, err := db.Prepare(sql)
if err != nil {
return nil, err
}
defer s.Close()
res, err := s.Query(params...)
if err != nil {
return nil, err
}
defer res.Close()
fields, err := res.Columns()
if err != nil {
return nil, err
}
for res.Next() {
result := make(map[string][]byte)
var scanResultContainers []interface{}
for i := 0; i < len(fields); i++ {
var scanResultContainer interface{}
scanResultContainers = append(scanResultContainers, &scanResultContainer)
}
if err := res.Scan(scanResultContainers...); err != nil {
return nil, err
}
for ii, key := range fields {
rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
//if row is null then ignore
if rawValue.Interface() == nil {
//fmt.Println("ignore ...", key, rawValue)
continue
}
aa := reflect.TypeOf(rawValue.Interface())
vv := reflect.ValueOf(rawValue.Interface())
var str string
switch aa.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
str = strconv.FormatInt(vv.Int(), 10)
result[key] = []byte(str)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
str = strconv.FormatUint(vv.Uint(), 10)
result[key] = []byte(str)
case reflect.Float32, reflect.Float64:
str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
result[key] = []byte(str)
case reflect.Slice:
switch aa.Elem().Kind() {
case reflect.Uint8:
result[key] = rawValue.Interface().([]byte)
default:
//session.Engine.LogError("Unsupported type")
}
case reflect.String:
str = vv.String()
result[key] = []byte(str)
//时间类型
case reflect.Struct:
if aa.String() == "time.Time" {
str = rawValue.Interface().(time.Time).Format("2006-01-02 15:04:05.000 -0700")
result[key] = []byte(str)
} else {
//session.Engine.LogError("Unsupported struct type")
}
default:
//session.Engine.LogError("Unsupported type")
}
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
}
func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
for _, filter := range session.Engine.Filters {
sql = filter.Do(sql, session)
@ -953,7 +1044,9 @@ func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice
session.Engine.LogSQL(sql)
session.Engine.LogSQL(paramStr)
s, err := session.Db.Prepare(sql)
return query(session.Db, sql, paramStr...)
/*s, err := session.Db.Prepare(sql)
if err != nil {
return nil, err
}
@ -1022,7 +1115,7 @@ func (session *Session) query(sql string, paramStr ...interface{}) (resultsSlice
}
resultsSlice = append(resultsSlice, result)
}
return resultsSlice, nil
return resultsSlice, nil*/
}
func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) {
@ -1446,9 +1539,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
var v interface{} = id
switch pkValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32:
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int:
v = int(id)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
v = uint(id)
}
pkValue.Set(reflect.ValueOf(v))
@ -1456,6 +1549,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
return id, nil
}
// Method InsertOne insert only one struct into database as a record.
// The in parameter bean must a struct or a point to struct. The return
// parameter is lastInsertId and error
func (session *Session) InsertOne(bean interface{}) (int64, error) {
err := session.newDb()
if err != nil {

View File

@ -1,9 +1,11 @@
package xorm
type sqlite3 struct {
base
}
func (db *sqlite3) Init(uri string) error {
func (db *sqlite3) Init(drivername, dataSourceName string) error {
db.base.init(drivername, dataSourceName)
return nil
}
@ -69,3 +71,22 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac
args := []interface{}{tableName}
return "SELECT name FROM sqlite_master WHERE type='table' and name = ? and sql like '%`" + colName + "`%'", args
}
func (db *sqlite3) GetColumns(tableName string) (map[string]*Column, error) {
/*args := []interface{}{db.dbname, tableName}
SELECT sql FROM sqlite_master WHERE type='table' and name = 'category';
sql := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," +
" `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?"
return sql, args*/
return nil, ErrNotImplemented
}
func (db *sqlite3) GetTables() ([]*Table, error) {
return nil, ErrNotImplemented
}
func (db *sqlite3) GetIndexes(tableName string) (map[string]*Index, error) {
return nil, ErrNotImplemented
}

View File

@ -262,15 +262,15 @@ func (statement *Statement) genCreateSQL() string {
sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " ("
for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[colName]
sql += col.String(statement.Engine)
sql += col.String(statement.Engine.dialect)
sql = strings.TrimSpace(sql)
sql += ", "
}
sql = sql[:len(sql)-2] + ")"
if statement.Engine.Dialect.SupportEngine() && statement.StoreEngine != "" {
if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" {
sql += " ENGINE=" + statement.StoreEngine
}
if statement.Engine.Dialect.SupportCharset() && statement.Charset != "" {
if statement.Engine.dialect.SupportCharset() && statement.Charset != "" {
sql += " DEFAULT CHARSET " + statement.Charset
}
sql += ";"
@ -286,10 +286,12 @@ func (s *Statement) genIndexSQL() []string {
tbName := s.TableName()
quote := s.Engine.Quote
for idxName, index := range s.RefTable.Indexes {
if index.Type == IndexType {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(index.GenColsStr(), quote(","))))
quote(tbName), quote(strings.Join(index.Cols, quote(","))))
sqls = append(sqls, sql)
}
}
return sqls
}
@ -302,10 +304,12 @@ func (s *Statement) genUniqueSQL() []string {
tbName := s.TableName()
quote := s.Engine.Quote
for idxName, unique := range s.RefTable.Indexes {
if unique.Type == UniqueType {
sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)),
quote(tbName), quote(strings.Join(unique.GenColsStr(), quote(","))))
quote(tbName), quote(strings.Join(unique.Cols, quote(","))))
sqls = append(sqls, sql)
}
}
return sqls
}
@ -313,13 +317,13 @@ func (s *Statement) genDelIndexSQL() []string {
var sqls []string = make([]string, 0)
for idxName, index := range s.RefTable.Indexes {
var rIdxName string
if index.IsUnique {
if index.Type == UniqueType {
rIdxName = uniqueName(s.TableName(), idxName)
} else {
} else if index.Type == IndexType {
rIdxName = indexName(s.TableName(), idxName)
}
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName))
if s.Engine.Dialect.IndexOnTable() {
if s.Engine.dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName()))
}
sqls = append(sqls, sql)
@ -351,7 +355,7 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) {
func (s *Statement) genAddColumnStr(col *Column) (string, []interface{}) {
quote := s.Engine.Quote
sql := fmt.Sprintf("ALTER TABLE %v ADD COLUMN %v;", quote(s.TableName()),
col.String(s.Engine))
col.String(s.Engine.dialect))
return sql, []interface{}{}
}

View File

@ -143,35 +143,57 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
return
}
func SQLType2Type(st SQLType) reflect.Type {
switch st.Name {
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
return reflect.TypeOf(1)
case BigInt, BigSerial:
return reflect.TypeOf(int64(1))
case Float, Real:
return reflect.TypeOf(float32(1))
case Double:
return reflect.TypeOf(float64(1))
case Char, Varchar, TinyText, Text, MediumText, LongText:
return reflect.TypeOf("")
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
return reflect.TypeOf([]byte{})
case Bool:
return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp:
return reflect.TypeOf(tm)
case Decimal, Numeric:
return reflect.TypeOf("")
default:
return reflect.TypeOf("")
}
}
const (
TWOSIDES = iota + 1
ONLYTODB
ONLYFROMDB
IndexType = iota + 1
UniqueType
)
type Index struct {
Name string
IsUnique bool
Cols []*Column
Type int
Cols []string
}
func (index *Index) AddColumn(cols ...*Column) {
func (index *Index) AddColumn(cols ...string) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
func (index *Index) GenColsStr() []string {
names := make([]string, len(index.Cols))
for idx, col := range index.Cols {
names[idx] = col.Name
}
return names
func NewIndex(name string, indexType int) *Index {
return &Index{name, indexType, make([]string, 0)}
}
func NewIndex(name string, isUnique bool) *Index {
return &Index{name, isUnique, make([]*Column, 0)}
}
const (
TWOSIDES = iota + 1
ONLYTODB
ONLYFROMDB
)
type Column struct {
Name string
@ -181,26 +203,26 @@ type Column struct {
Length2 int
Nullable bool
Default string
Index *Index
Indexes map[string]bool
IsPrimaryKey bool
IsAutoIncrement bool
MapType int
IsCreated bool
IsUpdated bool
Comment string
IsCascade bool
}
func (col *Column) String(engine *Engine) string {
sql := engine.Quote(col.Name) + " "
func (col *Column) String(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += engine.SqlType(col) + " "
sql += d.SqlType(col) + " "
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
}
if col.IsAutoIncrement {
sql += engine.AutoIncrStr() + " "
sql += d.AutoIncrStr() + " "
}
if col.Nullable {
@ -213,9 +235,6 @@ func (col *Column) String(engine *Engine) string {
sql += "DEFAULT " + col.Default + " "
}
if col.Comment != "" {
sql += "COMMENT '" + col.Comment + "' "
}
return sql
}

12
xorm.go
View File

@ -10,7 +10,7 @@ import (
)
const (
version string = "0.1.9"
version string = "0.2.0"
)
func close(engine *Engine) {
@ -24,19 +24,19 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
DataSourceName: dataSourceName, Filters: make([]Filter, 0)}
if driverName == SQLITE {
engine.Dialect = &sqlite3{}
engine.dialect = &sqlite3{}
} else if driverName == MYSQL {
engine.Dialect = &mysql{}
engine.dialect = &mysql{}
} else if driverName == POSTGRES {
engine.Dialect = &postgres{}
engine.dialect = &postgres{}
engine.Filters = append(engine.Filters, &PgSeqFilter{})
engine.Filters = append(engine.Filters, &QuoteFilter{})
} else if driverName == MYMYSQL {
engine.Dialect = &mymysql{}
engine.dialect = &mymysql{}
} else {
return nil, errors.New(fmt.Sprintf("Unsupported driver name: %v", driverName))
}
err := engine.Dialect.Init(dataSourceName)
err := engine.dialect.Init(driverName, dataSourceName)
if err != nil {
return nil, err
}

1
xorm/c++.go Normal file
View File

@ -0,0 +1 @@
package main

50
xorm/cmd.go Normal file
View File

@ -0,0 +1,50 @@
package main
import (
"fmt"
"os"
"strings"
)
// A Command is an implementation of a go command
// like go build or go fix.
type Command struct {
// Run runs the command.
// The args are the arguments after the command name.
Run func(cmd *Command, args []string)
// UsageLine is the one-line usage message.
// The first word in the line is taken to be the command name.
UsageLine string
// Short is the short description shown in the 'go help' output.
Short string
// Long is the long message shown in the 'go help <this-command>' output.
Long string
// Flag is a set of flags specific to this command.
Flags map[string]bool
}
// Name returns the command's name: the first word in the usage line.
func (c *Command) Name() string {
name := c.UsageLine
i := strings.Index(name, " ")
if i >= 0 {
name = name[:i]
}
return name
}
func (c *Command) Usage() {
fmt.Fprintf(os.Stderr, "usage: %s\n\n", c.UsageLine)
fmt.Fprintf(os.Stderr, "%s\n", strings.TrimSpace(c.Long))
os.Exit(2)
}
// Runnable reports whether the command can be run; otherwise
// it is a documentation pseudo-command such as importpath.
func (c *Command) Runnable() bool {
return c.Run != nil
}

43
xorm/go.go Normal file
View File

@ -0,0 +1,43 @@
package main
import (
//"github.com/lunny/xorm"
"strings"
"xorm"
)
func typestring(st xorm.SQLType) string {
t := xorm.SQLType2Type(st)
s := t.String()
if s == "[]uint8" {
return "[]byte"
}
return s
}
func tag(col *xorm.Column) string {
res := make([]string, 0)
if !col.Nullable {
res = append(res, "not null")
}
if col.IsPrimaryKey {
res = append(res, "pk")
}
if col.Default != "" {
res = append(res, "default "+col.Default)
}
if col.IsAutoIncrement {
res = append(res, "autoincr")
}
if col.IsCreated {
res = append(res, "created")
}
if col.IsUpdated {
res = append(res, "updated")
}
if len(res) > 0 {
return "`xorm:\"" + strings.Join(res, " ") + "\"`"
}
return ""
}

20
xorm/install.sh Executable file
View File

@ -0,0 +1,20 @@
#!/usr/bin/env bash
if [ ! -f install.sh ]; then
echo 'install must be run within its container folder' 1>&2
exit 1
fi
CURDIR=`pwd`
NEWPATH="$GOPATH/src/github.com/lunny/xorm/${PWD##*/}"
if [ ! -d "$NEWPATH" ]; then
ln -s $CURDIR $NEWPATH
fi
gofmt -w $CURDIR
cd $NEWPATH
go install ${PWD##*/}
cd $CURDIR
echo 'finished'

176
xorm/reverse.go Normal file
View File

@ -0,0 +1,176 @@
package main
import (
"fmt"
//"github.com/lunny/xorm"
"bytes"
_ "github.com/bylevel/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"go/format"
"io/ioutil"
"os"
"path"
"path/filepath"
"text/template"
"xorm"
)
var CmdReverse = &Command{
UsageLine: "reverse -m driverName datasourceName tmplpath",
Short: "reverse a db to codes",
Long: `
according database's tables and columns to generate codes for Go, C++ and etc.
`,
}
func init() {
CmdReverse.Run = runReverse
CmdReverse.Flags = map[string]bool{}
}
func printReversePrompt(flag string) {
}
type Tmpl struct {
Table *xorm.Table
Imports map[string]string
Model string
}
func runReverse(cmd *Command, args []string) {
if len(args) < 3 {
fmt.Println("no")
return
}
curPath, err := os.Getwd()
if err != nil {
fmt.Println(curPath)
return
}
var genDir string
var model string
if len(args) == 4 {
genDir, err = filepath.Abs(args[3])
if err != nil {
fmt.Println(err)
return
}
model = path.Base(genDir)
} else {
model = "model"
genDir = path.Join(curPath, model)
}
os.MkdirAll(genDir, os.ModePerm)
Orm, err := xorm.NewEngine(args[0], args[1])
if err != nil {
fmt.Println(err)
return
}
tables, err := Orm.DBMetas()
if err != nil {
fmt.Println(err)
return
}
dir, err := filepath.Abs(args[2])
if err != nil {
fmt.Println(curPath)
return
}
var isMultiFile bool = true
m := &xorm.SnakeMapper{}
filepath.Walk(dir, func(f string, info os.FileInfo, err error) error {
if info.IsDir() {
return nil
}
bs, err := ioutil.ReadFile(f)
if err != nil {
fmt.Println(err)
return err
}
t := template.New(f)
t.Funcs(template.FuncMap{"Mapper": m.Table2Obj,
"Type": typestring,
"Tag": tag,
})
tmpl, err := t.Parse(string(bs))
if err != nil {
fmt.Println(err)
return err
}
var w *os.File
fileName := info.Name()
newFileName := fileName[:len(fileName)-4]
ext := path.Ext(newFileName)
if !isMultiFile {
w, err = os.OpenFile(path.Join(genDir, newFileName), os.O_RDWR|os.O_CREATE, 0700)
if err != nil {
fmt.Println(err)
return err
}
}
for _, table := range tables {
// imports
imports := make(map[string]string)
for _, col := range table.Columns {
if typestring(col.SQLType) == "time.Time" {
imports["time.Time"] = "time.Time"
}
}
if isMultiFile {
w, err = os.OpenFile(path.Join(genDir, m.Table2Obj(table.Name)+ext), os.O_RDWR|os.O_CREATE, 0700)
if err != nil {
fmt.Println(err)
return err
}
}
newbytes := bytes.NewBufferString("")
t := &Tmpl{Table: table, Imports: imports, Model: model}
err = tmpl.Execute(newbytes, t)
if err != nil {
fmt.Println(err)
return err
}
tplcontent, err := ioutil.ReadAll(newbytes)
if err != nil {
fmt.Println(err)
return err
}
source, err := format.Source(tplcontent)
if err != nil {
fmt.Println(err)
return err
}
w.WriteString(string(source))
if isMultiFile {
w.Close()
}
}
if !isMultiFile {
w.Close()
}
return nil
})
}

View File

@ -0,0 +1,11 @@
package {{.Model}}
import (
"github.com/lunny/xorm"
{{range .Imports}}"{{.}}"{{end}}
)
type {{Mapper .Table.Name}} struct {
{{range .Table.Columns}} {{Mapper .Name}} {{Type .SQLType}} {{Tag .}}
{{end}}
}

158
xorm/xorm.go Normal file
View File

@ -0,0 +1,158 @@
package main
import (
"fmt"
"io"
"os"
"runtime"
"strings"
"sync"
"text/template"
"unicode"
"unicode/utf8"
)
// +build go1.1
// Test that go1.1 tag above is included in builds. main.go refers to this definition.
const go11tag = true
// Commands lists the available commands and help topics.
// The order here is the order in which they are printed by 'gopm help'.
var commands = []*Command{
CmdReverse,
}
func init() {
runtime.GOMAXPROCS(runtime.NumCPU())
}
func main() {
// Check length of arguments.
args := os.Args[1:]
if len(args) < 1 {
usage()
return
}
// Show help documentation.
if args[0] == "help" {
help(args[1:])
return
}
// Check commands and run.
for _, comm := range commands {
if comm.Name() == args[0] && comm.Run != nil {
comm.Run(comm, args[1:])
exit()
return
}
}
fmt.Fprintf(os.Stderr, "xorm: unknown subcommand %q\nRun 'xorm help' for usage.\n", args[0])
setExitStatus(2)
exit()
}
var exitStatus = 0
var exitMu sync.Mutex
func setExitStatus(n int) {
exitMu.Lock()
if exitStatus < n {
exitStatus = n
}
exitMu.Unlock()
}
var usageTemplate = `xorm is a database tool based xorm package.
Usage:
xorm command [arguments]
The commands are:
{{range .}}{{if .Runnable}}
{{.Name | printf "%-11s"}} {{.Short}}{{end}}{{end}}
Use "xorm help [command]" for more information about a command.
Additional help topics:
{{range .}}{{if not .Runnable}}
{{.Name | printf "%-11s"}} {{.Short}}{{end}}{{end}}
Use "xorm help [topic]" for more information about that topic.
`
var helpTemplate = `{{if .Runnable}}usage: go {{.UsageLine}}
{{end}}{{.Long | trim}}
`
// tmpl executes the given template text on data, writing the result to w.
func tmpl(w io.Writer, text string, data interface{}) {
t := template.New("top")
t.Funcs(template.FuncMap{"trim": strings.TrimSpace, "capitalize": capitalize})
template.Must(t.Parse(text))
if err := t.Execute(w, data); err != nil {
panic(err)
}
}
func capitalize(s string) string {
if s == "" {
return s
}
r, n := utf8.DecodeRuneInString(s)
return string(unicode.ToTitle(r)) + s[n:]
}
func printUsage(w io.Writer) {
tmpl(w, usageTemplate, commands)
}
func usage() {
printUsage(os.Stderr)
os.Exit(2)
}
// help implements the 'help' command.
func help(args []string) {
if len(args) == 0 {
printUsage(os.Stdout)
// not exit 2: succeeded at 'gopm help'.
return
}
if len(args) != 1 {
fmt.Fprintf(os.Stderr, "usage: xorm help command\n\nToo many arguments given.\n")
os.Exit(2) // failed at 'gopm help'
}
arg := args[0]
for _, cmd := range commands {
if cmd.Name() == arg {
tmpl(os.Stdout, helpTemplate, cmd)
// not exit 2: succeeded at 'gopm help cmd'.
return
}
}
fmt.Fprintf(os.Stderr, "Unknown help topic %#q. Run 'xorm help'.\n", arg)
os.Exit(2) // failed at 'gopm help cmd'
}
var atexitFuncs []func()
func atexit(f func()) {
atexitFuncs = append(atexitFuncs, f)
}
func exit() {
for _, f := range atexitFuncs {
f()
}
os.Exit(exitStatus)
}