diff --git a/mssql_dialect.go b/mssql_dialect.go index 0b85f6a9..4f5b5eb7 100644 --- a/mssql_dialect.go +++ b/mssql_dialect.go @@ -17,8 +17,8 @@ type mssql struct { core.Base } -func (db *mssql) Init(uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(db, uri, drivername, dataSourceName) +func (db *mssql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { + return db.Base.Init(d, db, uri, drivername, dataSourceName) } func (db *mssql) SqlType(c *core.Column) string { @@ -123,13 +123,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale from sys.columns a left join sys.types b on a.user_type_id=b.user_type_id where a.object_id=object_id('` + tableName + `')` - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, nil, err } @@ -183,12 +178,8 @@ where a.object_id=object_id('` + tableName + `')` func (db *mssql) GetTables() ([]*core.Table, error) { args := []interface{}{} s := `select name from sysobjects where xtype ='U'` - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } @@ -223,12 +214,7 @@ INNER JOIN SYS.COLUMNS C ON IXS.OBJECT_ID=C.OBJECT_ID AND IXCS.COLUMN_ID=C.COLUMN_ID WHERE IXS.TYPE_DESC='NONCLUSTERED' and OBJECT_NAME(IXS.OBJECT_ID) =? ` - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } diff --git a/mysql_dialect.go b/mysql_dialect.go index 58a30c44..71273183 100644 --- a/mysql_dialect.go +++ b/mysql_dialect.go @@ -28,8 +28,8 @@ type mysql struct { clientFoundRows bool } -func (db *mysql) Init(uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(db, uri, drivername, dataSourceName) +func (db *mysql) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { + return db.Base.Init(d, db, uri, drivername, dataSourceName) } func (db *mysql) SqlType(c *core.Column) string { @@ -114,13 +114,8 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column args := []interface{}{db.DbName, tableName} s := "SELECT `COLUMN_NAME`, `IS_NULLABLE`, `COLUMN_DEFAULT`, `COLUMN_TYPE`," + " `COLUMN_KEY`, `EXTRA` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, nil, err } @@ -198,15 +193,12 @@ func (db *mysql) GetColumns(tableName string) ([]string, map[string]*core.Column func (db *mysql) GetTables() ([]*core.Table, error) { args := []interface{}{db.DbName} s := "SELECT `TABLE_NAME`, `ENGINE`, `TABLE_ROWS`, `AUTO_INCREMENT` from `INFORMATION_SCHEMA`.`TABLES` WHERE `TABLE_SCHEMA`=?" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } + defer rows.Close() tables := make([]*core.Table, 0) for rows.Next() { @@ -227,15 +219,12 @@ func (db *mysql) GetTables() ([]*core.Table, error) { func (db *mysql) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{db.DbName, tableName} s := "SELECT `INDEX_NAME`, `NON_UNIQUE`, `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ?" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } + defer rows.Close() indexes := make(map[string]*core.Index, 0) for rows.Next() { diff --git a/oracle_dialect.go b/oracle_dialect.go index febd318e..36a2f62d 100644 --- a/oracle_dialect.go +++ b/oracle_dialect.go @@ -17,8 +17,8 @@ type oracle struct { core.Base } -func (db *oracle) Init(uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(db, uri, drivername, dataSourceName) +func (db *oracle) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { + return db.Base.Init(d, db, uri, drivername, dataSourceName) } func (db *oracle) SqlType(c *core.Column) string { @@ -98,12 +98,7 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum s := "SELECT column_name,data_default,data_type,data_length,data_precision,data_scale," + "nullable FROM USER_TAB_COLUMNS WHERE table_name = :1" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, nil, err } @@ -166,15 +161,12 @@ func (db *oracle) GetColumns(tableName string) ([]string, map[string]*core.Colum func (db *oracle) GetTables() ([]*core.Table, error) { args := []interface{}{} s := "SELECT table_name FROM user_tables" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } + defer rows.Close() tables := make([]*core.Table, 0) for rows.Next() { @@ -194,12 +186,7 @@ func (db *oracle) GetIndexes(tableName string) (map[string]*core.Index, error) { s := "SELECT t.column_name,i.uniqueness,i.index_name FROM user_ind_columns t,user_indexes i " + "WHERE t.index_name = i.index_name and t.table_name = i.table_name and t.table_name =:1" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } diff --git a/postgres_dialect.go b/postgres_dialect.go index 51269d53..943039e5 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -17,8 +17,8 @@ type postgres struct { core.Base } -func (db *postgres) Init(uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(db, uri, drivername, dataSourceName) +func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { + return db.Base.Init(d, db, uri, drivername, dataSourceName) } func (db *postgres) SqlType(c *core.Column) string { @@ -112,15 +112,13 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col args := []interface{}{tableName} s := "SELECT column_name, column_default, is_nullable, data_type, character_maximum_length" + ", numeric_precision, numeric_precision_radix FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, nil, err } + defer rows.Close() + cols := make(map[string]*core.Column) colSeq := make([]string, 0) @@ -200,15 +198,12 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col func (db *postgres) GetTables() ([]*core.Table, error) { args := []interface{}{} s := "SELECT tablename FROM pg_tables where schemaname = 'public'" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } + defer rows.Close() tables := make([]*core.Table, 0) for rows.Next() { @@ -228,15 +223,11 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) args := []interface{}{tableName} s := "SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = 'public' and tablename = $1" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } + defer rows.Close() indexes := make(map[string]*core.Index, 0) for rows.Next() { diff --git a/session.go b/session.go index c67c46b1..881977a9 100644 --- a/session.go +++ b/session.go @@ -155,11 +155,6 @@ func (session *Session) NoCascade() *Session { return session } -/* -func (session *Session) MustCols(columns ...string) *Session { - session.Statement.Must() -}*/ - // Xorm automatically retrieve condition according struct, but // if struct has bool field, it will ignore them. So use UseBool // to tell system to do not ignore them. @@ -2443,11 +2438,12 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val switch k { case reflect.Bool: - if fieldValue.Bool() { + return fieldValue.Bool(), nil + /*if fieldValue.Bool() { return 1, nil } else { return 0, nil - } + }*/ case reflect.String: return fieldValue.String(), nil case reflect.Struct: diff --git a/sqlite3_dialect.go b/sqlite3_dialect.go index d5be5c72..0e19f96c 100644 --- a/sqlite3_dialect.go +++ b/sqlite3_dialect.go @@ -14,8 +14,8 @@ type sqlite3 struct { core.Base } -func (db *sqlite3) Init(uri *core.Uri, drivername, dataSourceName string) error { - return db.Base.Init(db, uri, drivername, dataSourceName) +func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error { + return db.Base.Init(d, db, uri, drivername, dataSourceName) } func (db *sqlite3) SqlType(c *core.Column) string { @@ -87,13 +87,8 @@ func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interfac func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, nil, err } @@ -147,12 +142,7 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) { args := []interface{}{} s := "SELECT name FROM sqlite_master WHERE type='table'" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } @@ -176,12 +166,8 @@ func (db *sqlite3) GetTables() ([]*core.Table, error) { func (db *sqlite3) GetIndexes(tableName string) (map[string]*core.Index, error) { args := []interface{}{tableName} s := "SELECT sql FROM sqlite_master WHERE type='index' and tbl_name = ?" - cnn, err := core.Open(db.DriverName(), db.DataSourceName()) - if err != nil { - return nil, err - } - defer cnn.Close() - rows, err := cnn.Query(s, args...) + + rows, err := db.DB().Query(s, args...) if err != nil { return nil, err } diff --git a/xorm.go b/xorm.go index e0f43c2e..53615564 100644 --- a/xorm.go +++ b/xorm.go @@ -4,11 +4,12 @@ import ( "database/sql" "errors" "fmt" - "github.com/go-xorm/core" "os" "reflect" "runtime" "sync" + + "github.com/go-xorm/core" ) const ( @@ -75,12 +76,12 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) { return nil, errors.New(fmt.Sprintf("Unsupported dialect type: %v", uri.DbType)) } - err = dialect.Init(uri, driverName, dataSourceName) + db, err := core.Open(driverName, dataSourceName) if err != nil { return nil, err } - db, err := core.OpenDialect(dialect) + err = dialect.Init(db, uri, driverName, dataSourceName) if err != nil { return nil, err }