2017-01-03 05:31:47 +00:00
// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"database/sql"
"fmt"
2020-02-19 13:16:54 +00:00
"regexp"
2017-01-03 05:31:47 +00:00
"strings"
2019-06-17 05:38:13 +00:00
"xorm.io/core"
2017-01-03 05:31:47 +00:00
)
// Ping test if database is ok
func ( session * Session ) Ping ( ) error {
2017-07-27 05:32:35 +00:00
if session . isAutoClose {
2017-01-03 05:31:47 +00:00
defer session . Close ( )
}
2017-08-20 09:05:42 +00:00
session . engine . logger . Infof ( "PING DATABASE %v" , session . engine . DriverName ( ) )
2019-01-20 03:01:14 +00:00
return session . DB ( ) . PingContext ( session . ctx )
2017-01-03 05:31:47 +00:00
}
// CreateTable create a table according a bean
func ( session * Session ) CreateTable ( bean interface { } ) error {
2017-08-20 09:05:42 +00:00
if session . isAutoClose {
defer session . Close ( )
}
return session . createTable ( bean )
}
func ( session * Session ) createTable ( bean interface { } ) error {
2018-04-10 01:50:29 +00:00
if err := session . statement . setRefBean ( bean ) ; err != nil {
2017-04-10 11:45:00 +00:00
return err
}
2017-01-03 05:31:47 +00:00
2017-08-20 09:05:42 +00:00
sqlStr := session . statement . genCreateTableSQL ( )
_ , err := session . exec ( sqlStr )
return err
}
// CreateIndexes create indexes
func ( session * Session ) CreateIndexes ( bean interface { } ) error {
2017-07-27 05:32:35 +00:00
if session . isAutoClose {
2017-01-03 05:31:47 +00:00
defer session . Close ( )
}
2017-08-20 09:05:42 +00:00
return session . createIndexes ( bean )
2017-01-03 05:31:47 +00:00
}
2017-08-20 09:05:42 +00:00
func ( session * Session ) createIndexes ( bean interface { } ) error {
2018-04-10 01:50:29 +00:00
if err := session . statement . setRefBean ( bean ) ; err != nil {
2017-04-10 11:45:00 +00:00
return err
}
2017-01-03 05:31:47 +00:00
2017-07-27 05:32:35 +00:00
sqls := session . statement . genIndexSQL ( )
2017-01-03 05:31:47 +00:00
for _ , sqlStr := range sqls {
_ , err := session . exec ( sqlStr )
if err != nil {
return err
}
}
return nil
}
// CreateUniques create uniques
func ( session * Session ) CreateUniques ( bean interface { } ) error {
2017-08-20 09:05:42 +00:00
if session . isAutoClose {
defer session . Close ( )
2017-04-10 11:45:00 +00:00
}
2017-08-20 09:05:42 +00:00
return session . createUniques ( bean )
}
2017-01-03 05:31:47 +00:00
2017-08-20 09:05:42 +00:00
func ( session * Session ) createUniques ( bean interface { } ) error {
2018-04-10 01:50:29 +00:00
if err := session . statement . setRefBean ( bean ) ; err != nil {
2017-08-20 09:05:42 +00:00
return err
2017-01-03 05:31:47 +00:00
}
2017-07-27 05:32:35 +00:00
sqls := session . statement . genUniqueSQL ( )
2017-01-03 05:31:47 +00:00
for _ , sqlStr := range sqls {
_ , err := session . exec ( sqlStr )
if err != nil {
return err
}
}
return nil
}
// DropIndexes drop indexes
func ( session * Session ) DropIndexes ( bean interface { } ) error {
2017-08-20 09:05:42 +00:00
if session . isAutoClose {
defer session . Close ( )
2017-04-10 11:45:00 +00:00
}
2017-01-03 05:31:47 +00:00
2017-08-20 09:05:42 +00:00
return session . dropIndexes ( bean )
}
func ( session * Session ) dropIndexes ( bean interface { } ) error {
2018-04-10 01:50:29 +00:00
if err := session . statement . setRefBean ( bean ) ; err != nil {
2017-08-20 09:05:42 +00:00
return err
2017-01-03 05:31:47 +00:00
}
2017-07-27 05:32:35 +00:00
sqls := session . statement . genDelIndexSQL ( )
2017-01-03 05:31:47 +00:00
for _ , sqlStr := range sqls {
_ , err := session . exec ( sqlStr )
if err != nil {
return err
}
}
return nil
}
2020-02-19 13:16:54 +00:00
// DropTableCols drops the named columns from the table identified by
// beanOrTableName. An auto-close session is closed when the call returns.
func (session *Session) DropTableCols(beanOrTableName interface{}, cols ...string) error {
	if session.isAutoClose {
		defer session.Close()
	}

	err := session.dropTableCols(beanOrTableName, cols)
	return err
}
// dropTableCols removes cols from the table identified by beanOrTableName,
// using a strategy that depends on the session's dialect:
//
//   - SQLite: drops single-column indexes on the columns, then rebuilds the
//     table (create new table without the columns, copy data, drop the old
//     table, rename the new one).
//   - PostgreSQL: a single ALTER TABLE with DROP COLUMN ... CASCADE per column.
//   - MySQL: looks up indexes on the columns, then a single ALTER TABLE with
//     one DROP COLUMN per column.
//   - MSSQL: drops default constraints bound to the columns, drops the
//     columns, then commits the session.
//   - Oracle and unknown dialects: returns an error.
//
// It returns nil without doing anything when the table name resolves to the
// empty string or cols is empty.
func (session *Session) dropTableCols(beanOrTableName interface{}, cols []string) error {
	tableName := session.engine.TableName(beanOrTableName)
	if tableName == "" || len(cols) == 0 {
		return nil
	}

	// TODO: This will not work if there are foreign keys
	switch session.engine.dialect.DBType() {
	case core.SQLITE:
		// First drop the indexes on the columns
		res, errIndex := session.Query(fmt.Sprintf("PRAGMA index_list(`%s`)", tableName))
		if errIndex != nil {
			return errIndex
		}
		for _, row := range res {
			indexName := row["name"]
			indexRes, err := session.Query(fmt.Sprintf("PRAGMA index_info(`%s`)", indexName))
			if err != nil {
				return err
			}
			// Only indexes covering exactly one column are considered.
			if len(indexRes) != 1 {
				continue
			}
			indexColumn := string(indexRes[0]["name"])
			for _, name := range cols {
				if name != indexColumn {
					continue
				}
				if _, err := session.Exec(fmt.Sprintf("DROP INDEX `%s`", indexName)); err != nil {
					return err
				}
				// The index is gone; a duplicate entry in cols must not
				// try to drop it a second time.
				break
			}
		}

		// Here we need to get the columns from the original table
		query := fmt.Sprintf("SELECT sql FROM sqlite_master WHERE tbl_name='%s' and type='table'", tableName)
		res, err := session.Query(query)
		if err != nil {
			return err
		}
		if len(res) == 0 {
			// Without this guard res[0] would panic for a missing table.
			return fmt.Errorf("table `%s` not found in sqlite_master", tableName)
		}
		tableSQL := string(res[0]["sql"])

		// Separate out the column definitions
		start := strings.Index(tableSQL, "(")
		if start < 0 {
			return fmt.Errorf("unexpected definition for table `%s`: %q", tableName, tableSQL)
		}
		tableSQL = tableSQL[start:]

		// Remove the required cols
		for _, name := range cols {
			tableSQL = regexp.MustCompile(regexp.QuoteMeta("`"+name+"`")+"[^`,)]*?[,)]").ReplaceAllString(tableSQL, "")
		}

		// Ensure the query is ended properly: removing the last column may
		// leave a trailing comma and no closing parenthesis.
		tableSQL = strings.TrimSpace(tableSQL)
		if !strings.HasSuffix(tableSQL, ")") {
			tableSQL = strings.TrimSuffix(tableSQL, ",") + ")"
		}

		// Find all the columns in the table (matches keep their backticks)
		columns := regexp.MustCompile("`([^`]*)`").FindAllString(tableSQL, -1)

		tableSQL = fmt.Sprintf("CREATE TABLE `new_%s_new` ", tableName) + tableSQL
		if _, err := session.Exec(tableSQL); err != nil {
			return err
		}

		// Now restore the data
		columnsSeparated := strings.Join(columns, ",")
		insertSQL := fmt.Sprintf("INSERT INTO `new_%s_new` (%s) SELECT %s FROM `%s`", tableName, columnsSeparated, columnsSeparated, tableName)
		if _, err := session.Exec(insertSQL); err != nil {
			return err
		}

		// Now drop the old table
		if _, err := session.Exec(fmt.Sprintf("DROP TABLE `%s`", tableName)); err != nil {
			return err
		}

		// Rename the table
		if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `new_%s_new` RENAME TO `%s`", tableName, tableName)); err != nil {
			return err
		}

	case core.POSTGRES:
		drops := make([]string, 0, len(cols))
		for _, col := range cols {
			drops = append(drops, "DROP COLUMN `"+col+"` CASCADE")
		}
		if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, strings.Join(drops, ", "))); err != nil {
			return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
		}

	case core.MYSQL:
		// Drop indexes on columns first
		query := fmt.Sprintf("SHOW INDEX FROM %s WHERE column_name IN ('%s')", tableName, strings.Join(cols, "','"))
		res, err := session.Query(query)
		if err != nil {
			return err
		}
		for _, index := range res {
			// NOTE(review): SHOW INDEX reports the index name in Key_name;
			// this reads the "column_name" key instead — confirm this is
			// the intended lookup before relying on the index drop.
			indexName := index["column_name"]
			if len(indexName) > 0 {
				if _, err := session.Exec(fmt.Sprintf("DROP INDEX `%s` ON `%s`", indexName, tableName)); err != nil {
					return err
				}
			}
		}

		// Now drop the columns
		drops := make([]string, 0, len(cols))
		for _, col := range cols {
			drops = append(drops, "DROP COLUMN `"+col+"`")
		}
		if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` %s", tableName, strings.Join(drops, ", "))); err != nil {
			return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
		}

	case core.MSSQL:
		quoted := make([]string, 0, len(cols))
		for _, col := range cols {
			quoted = append(quoted, "`"+strings.ToLower(col)+"`")
		}
		columns := strings.Join(quoted, ", ")

		// The same list is reused as string literals for the lookup below,
		// hence the backtick -> single quote replacement.
		query := fmt.Sprintf("SELECT Name FROM SYS.DEFAULT_CONSTRAINTS WHERE PARENT_OBJECT_ID = OBJECT_ID('%[1]s') AND PARENT_COLUMN_ID IN (SELECT column_id FROM sys.columns WHERE lower(NAME) IN (%[2]s) AND object_id = OBJECT_ID('%[1]s'))",
			tableName, strings.Replace(columns, "`", "'", -1))
		constraints := make([]string, 0)
		// NOTE(review): Rollback/Commit are called here but no Begin is
		// visible in this function — presumably the caller has started a
		// transaction on this session; confirm.
		if err := session.SQL(query).Find(&constraints); err != nil {
			session.Rollback()
			return fmt.Errorf("Find constraints: %v", err)
		}
		for _, constraint := range constraints {
			if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP CONSTRAINT `%s`", tableName, constraint)); err != nil {
				session.Rollback()
				return fmt.Errorf("Drop table `%s` constraint `%s`: %v", tableName, constraint, err)
			}
		}
		if _, err := session.Exec(fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN %s", tableName, columns)); err != nil {
			session.Rollback()
			return fmt.Errorf("Drop table `%s` columns %v: %v", tableName, cols, err)
		}
		return session.Commit()

	case core.ORACLE:
		return fmt.Errorf("not implemented for oracle")

	default:
		return fmt.Errorf("unrecognized DB")
	}

	return nil
}
2017-01-03 05:31:47 +00:00
// DropTable drop table will drop table if exist, if drop failed, it will return error
func ( session * Session ) DropTable ( beanOrTableName interface { } ) error {
2017-08-20 09:05:42 +00:00
if session . isAutoClose {
defer session . Close ( )
}
return session . dropTable ( beanOrTableName )
}
func ( session * Session ) dropTable ( beanOrTableName interface { } ) error {
2018-05-04 02:51:34 +00:00
tableName := session . engine . TableName ( beanOrTableName )
2017-01-03 05:31:47 +00:00
var needDrop = true
2017-07-27 05:32:35 +00:00
if ! session . engine . dialect . SupportDropIfExists ( ) {
sqlStr , args := session . engine . dialect . TableCheckSql ( tableName )
2017-08-27 07:50:43 +00:00
results , err := session . queryBytes ( sqlStr , args ... )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
needDrop = len ( results ) > 0
}
if needDrop {
2018-04-10 01:50:29 +00:00
sqlStr := session . engine . Dialect ( ) . DropTableSql ( session . engine . TableName ( tableName , true ) )
_ , err := session . exec ( sqlStr )
2017-01-03 05:31:47 +00:00
return err
}
return nil
}
// IsTableExist if a table is exist
func ( session * Session ) IsTableExist ( beanOrTableName interface { } ) ( bool , error ) {
2017-08-20 09:05:42 +00:00
if session . isAutoClose {
defer session . Close ( )
}
2018-05-04 02:51:34 +00:00
tableName := session . engine . TableName ( beanOrTableName )
2017-01-03 05:31:47 +00:00
return session . isTableExist ( tableName )
}
func ( session * Session ) isTableExist ( tableName string ) ( bool , error ) {
2017-07-27 05:32:35 +00:00
sqlStr , args := session . engine . dialect . TableCheckSql ( tableName )
2017-08-27 07:50:43 +00:00
results , err := session . queryBytes ( sqlStr , args ... )
2017-01-03 05:31:47 +00:00
return len ( results ) > 0 , err
}
// IsTableEmpty if table have any records
func ( session * Session ) IsTableEmpty ( bean interface { } ) ( bool , error ) {
2018-04-10 01:50:29 +00:00
if session . isAutoClose {
defer session . Close ( )
2017-01-03 05:31:47 +00:00
}
2018-05-04 02:51:34 +00:00
return session . isTableEmpty ( session . engine . TableName ( bean ) )
2017-01-03 05:31:47 +00:00
}
func ( session * Session ) isTableEmpty ( tableName string ) ( bool , error ) {
var total int64
2018-04-10 01:50:29 +00:00
sqlStr := fmt . Sprintf ( "select count(*) from %s" , session . engine . Quote ( session . engine . TableName ( tableName , true ) ) )
2017-08-27 07:50:43 +00:00
err := session . queryRow ( sqlStr ) . Scan ( & total )
2017-01-03 05:31:47 +00:00
if err != nil {
if err == sql . ErrNoRows {
err = nil
}
return true , err
}
return total == 0 , nil
}
// find if index is exist according cols
func ( session * Session ) isIndexExist2 ( tableName string , cols [ ] string , unique bool ) ( bool , error ) {
2017-07-27 05:32:35 +00:00
indexes , err := session . engine . dialect . GetIndexes ( tableName )
2017-01-03 05:31:47 +00:00
if err != nil {
return false , err
}
for _ , index := range indexes {
if sliceEq ( index . Cols , cols ) {
if unique {
return index . Type == core . UniqueType , nil
}
return index . Type == core . IndexType , nil
}
}
return false , nil
}
func ( session * Session ) addColumn ( colName string ) error {
2017-07-27 05:32:35 +00:00
col := session . statement . RefTable . GetColumn ( colName )
sql , args := session . statement . genAddColumnStr ( col )
2017-01-03 05:31:47 +00:00
_ , err := session . exec ( sql , args ... )
return err
}
func ( session * Session ) addIndex ( tableName , idxName string ) error {
2017-07-27 05:32:35 +00:00
index := session . statement . RefTable . Indexes [ idxName ]
sqlStr := session . engine . dialect . CreateIndexSql ( tableName , index )
2017-01-03 05:31:47 +00:00
_ , err := session . exec ( sqlStr )
return err
}
func ( session * Session ) addUnique ( tableName , uqeName string ) error {
2017-07-27 05:32:35 +00:00
index := session . statement . RefTable . Indexes [ uqeName ]
sqlStr := session . engine . dialect . CreateIndexSql ( tableName , index )
2017-01-03 05:31:47 +00:00
_ , err := session . exec ( sqlStr )
return err
}
// Sync2 synchronizes the given struct beans to database tables.
//
// The list of existing tables is read once from the dialect. Then, for each
// bean:
//   - if no existing table matches the bean's table name (case-insensitive,
//     compared with schema applied), the table, its unique indexes and its
//     indexes are created;
//   - otherwise columns missing in the database are added, a few column type
//     changes are applied on specific dialects (see below), differences in
//     type, default and nullability are logged as warnings, and indexes are
//     reconciled: database indexes that match no struct index are dropped and
//     struct indexes missing in the database are created.
//
// Columns that exist in the database but have no struct field are only
// reported with a warning; they are not dropped.
//
// It returns the first error encountered.
func ( session * Session ) Sync2 ( beans ... interface { } ) error {
2017-07-27 05:32:35 +00:00
engine := session . engine
2017-01-03 05:31:47 +00:00
2017-08-20 09:05:42 +00:00
// Auto-close is switched off for the duration of the sync and the session
// is closed once on return instead (presumably so the internal calls below
// do not close it midway).
if session . isAutoClose {
session . isAutoClose = false
defer session . Close ( )
}
2019-09-30 08:32:57 +00:00
tables , err := engine . dialect . GetTables ( )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
2018-03-05 10:01:47 +00:00
// Statement auto-reset is disabled while syncing; on return it is
// re-enabled and the statement is reset explicitly.
session . autoResetStatement = false
defer func ( ) {
session . autoResetStatement = true
session . resetStatement ( )
} ( )
2017-01-03 05:31:47 +00:00
for _ , bean := range beans {
v := rValue ( bean )
2017-03-30 02:39:38 +00:00
table , err := engine . mapType ( v )
if err != nil {
return err
}
2019-10-02 04:37:53 +00:00
// An alternative table name set on the statement takes precedence over
// the name derived from the bean.
var tbName string
if len ( session . statement . AltTableName ) > 0 {
tbName = session . statement . AltTableName
} else {
tbName = engine . TableName ( bean )
}
tbNameWithSchema := engine . tbNameWithSchema ( tbName )
2017-01-03 05:31:47 +00:00
// Look for an existing database table with the same schema-qualified
// name, ignoring case.
var oriTable * core . Table
for _ , tb := range tables {
2019-10-02 04:37:53 +00:00
if strings . EqualFold ( engine . tbNameWithSchema ( tb . Name ) , engine . tbNameWithSchema ( tbName ) ) {
2017-01-03 05:31:47 +00:00
oriTable = tb
break
}
}
2019-09-30 08:32:57 +00:00
// this is a new table: create it together with its uniques and indexes,
// then move on to the next bean
2017-01-03 05:31:47 +00:00
if oriTable == nil {
2017-08-20 09:05:42 +00:00
err = session . StoreEngine ( session . statement . StoreEngine ) . createTable ( bean )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
2017-08-20 09:05:42 +00:00
err = session . createUniques ( bean )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
2017-08-20 09:05:42 +00:00
err = session . createIndexes ( bean )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
2019-09-30 08:32:57 +00:00
continue
}
2017-01-03 05:31:47 +00:00
2019-09-30 08:32:57 +00:00
// this will modify an old table
if err = engine . loadTableInfo ( oriTable ) ; err != nil {
return err
}
// check columns
for _ , col := range table . Columns ( ) {
// Find the database column with the same name, ignoring case.
var oriCol * core . Column
for _ , col2 := range oriTable . Columns ( ) {
if strings . EqualFold ( col . Name , col2 . Name ) {
oriCol = col2
break
2017-01-03 05:31:47 +00:00
}
2019-09-30 08:32:57 +00:00
}
// column is not exist on table: add it and continue with the next column
if oriCol == nil {
session . statement . RefTable = table
session . statement . tableName = tbNameWithSchema
if err = session . addColumn ( col . Name ) ; err != nil {
2017-01-03 05:31:47 +00:00
return err
}
2019-09-30 08:32:57 +00:00
continue
2017-01-03 05:31:47 +00:00
}
2019-09-30 08:32:57 +00:00
// err is cleared here and only checked after the default/nullable
// warnings below, so it reflects at most one ModifyColumnSql exec for
// this column.
err = nil
expectedType := engine . dialect . SqlType ( col )
curType := engine . dialect . SqlType ( oriCol )
if expectedType != curType {
// Struct wants Text while the database column is a varchar: the
// column is modified on MySQL and Postgres, otherwise only a warning
// is logged.
if expectedType == core . Text &&
strings . HasPrefix ( curType , core . Varchar ) {
// currently only support mysql & postgres
if engine . dialect . DBType ( ) == core . MYSQL ||
engine . dialect . DBType ( ) == core . POSTGRES {
engine . logger . Infof ( "Table %s column %s change type from %s to %s\n" ,
tbNameWithSchema , col . Name , curType , expectedType )
_ , err = session . exec ( engine . dialect . ModifyColumnSql ( tbNameWithSchema , col ) )
} else {
engine . logger . Warnf ( "Table %s column %s db type is %s, struct type is %s\n" ,
tbNameWithSchema , col . Name , curType , expectedType )
2017-01-03 05:31:47 +00:00
}
2019-09-30 08:32:57 +00:00
// Both types are varchars: on MySQL the column is modified when the
// database length is smaller than the struct length; nothing is done
// on other dialects.
} else if strings . HasPrefix ( curType , core . Varchar ) && strings . HasPrefix ( expectedType , core . Varchar ) {
if engine . dialect . DBType ( ) == core . MYSQL {
if oriCol . Length < col . Length {
engine . logger . Infof ( "Table %s column %s change type from varchar(%d) to varchar(%d)\n" ,
tbNameWithSchema , col . Name , oriCol . Length , col . Length )
_ , err = session . exec ( engine . dialect . ModifyColumnSql ( tbNameWithSchema , col ) )
2017-01-03 05:31:47 +00:00
}
2019-09-30 08:32:57 +00:00
}
} else {
// Any other mismatch is only warned about, except when the database
// type is the expected type immediately followed by '('. The index
// is in range because curType differs from expectedType yet has it
// as a prefix, so it is strictly longer.
if ! ( strings . HasPrefix ( curType , expectedType ) && curType [ len ( expectedType ) ] == '(' ) {
engine . logger . Warnf ( "Table %s column %s db type is %s, struct type is %s" ,
tbNameWithSchema , col . Name , curType , expectedType )
2017-01-03 05:31:47 +00:00
}
}
2019-09-30 08:32:57 +00:00
// Types are equal and are exactly core.Varchar: same MySQL-only
// length check as above.
} else if expectedType == core . Varchar {
if engine . dialect . DBType ( ) == core . MYSQL {
if oriCol . Length < col . Length {
engine . logger . Infof ( "Table %s column %s change type from varchar(%d) to varchar(%d)\n" ,
tbNameWithSchema , col . Name , oriCol . Length , col . Length )
_ , err = session . exec ( engine . dialect . ModifyColumnSql ( tbNameWithSchema , col ) )
}
2017-01-03 05:31:47 +00:00
}
}
2019-10-02 07:04:49 +00:00
2019-09-30 08:32:57 +00:00
// Differing defaults are warned about, except for bool columns where
// the struct default true/false (any case) corresponds to a database
// default of "1"/"0"; that case is deliberately an empty branch.
if col . Default != oriCol . Default {
2019-10-02 07:04:49 +00:00
if ( col . SQLType . Name == core . Bool || col . SQLType . Name == core . Boolean ) &&
( ( strings . EqualFold ( col . Default , "true" ) && oriCol . Default == "1" ) ||
( strings . EqualFold ( col . Default , "false" ) && oriCol . Default == "0" ) ) {
} else {
engine . logger . Warnf ( "Table %s Column %s db default is %s, struct default is %s" ,
tbName , col . Name , oriCol . Default , col . Default )
}
2019-09-30 08:32:57 +00:00
}
// Nullability differences are only warned about.
if col . Nullable != oriCol . Nullable {
engine . logger . Warnf ( "Table %s Column %s db nullable is %v, struct nullable is %v" ,
tbName , col . Name , oriCol . Nullable , col . Nullable )
}
// Deferred check of the column modification attempted above.
if err != nil {
return err
}
}
// foundIndexNames records the database index names that were matched by a
// struct index; addedNames collects the struct indexes that must be created.
var foundIndexNames = make ( map [ string ] bool )
var addedNames = make ( map [ string ] * core . Index )
2017-01-03 05:31:47 +00:00
2019-09-30 08:32:57 +00:00
for name , index := range table . Indexes {
var oriIndex * core . Index
2017-01-03 05:31:47 +00:00
for name2 , index2 := range oriTable . Indexes {
2019-09-30 08:32:57 +00:00
if index . Equal ( index2 ) {
oriIndex = index2
foundIndexNames [ name2 ] = true
break
}
}
// A matching database index of a different type is dropped and then
// treated as missing, so it is re-created below.
if oriIndex != nil {
if oriIndex . Type != index . Type {
sql := engine . dialect . DropIndexSql ( tbNameWithSchema , oriIndex )
2017-08-20 09:05:42 +00:00
_ , err = session . exec ( sql )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
2019-09-30 08:32:57 +00:00
oriIndex = nil
2017-01-03 05:31:47 +00:00
}
}
2019-09-30 08:32:57 +00:00
if oriIndex == nil {
addedNames [ name ] = index
}
}
// Drop every database index that no struct index matched.
for name2 , index2 := range oriTable . Indexes {
if _ , ok := foundIndexNames [ name2 ] ; ! ok {
sql := engine . dialect . DropIndexSql ( tbNameWithSchema , index2 )
_ , err = session . exec ( sql )
2017-01-03 05:31:47 +00:00
if err != nil {
return err
}
}
}
2019-09-30 08:32:57 +00:00
// Create the missing indexes. addUnique/addIndex read the index from
// session.statement.RefTable, so the statement is pointed at this table
// first. Index types other than UniqueType/IndexType are skipped.
for name , index := range addedNames {
if index . Type == core . UniqueType {
session . statement . RefTable = table
session . statement . tableName = tbNameWithSchema
err = session . addUnique ( tbNameWithSchema , name )
} else if index . Type == core . IndexType {
session . statement . RefTable = table
session . statement . tableName = tbNameWithSchema
err = session . addIndex ( tbNameWithSchema , name )
}
if err != nil {
return err
2017-01-03 05:31:47 +00:00
}
}
2019-09-30 08:32:57 +00:00
// check all the columns which removed from struct fields but left on database tables.
for _ , colName := range oriTable . ColumnsSeq ( ) {
if table . GetColumn ( colName ) == nil {
engine . logger . Warnf ( "Table %s has column %s but struct has not related field" , engine . TableName ( oriTable . Name , true ) , colName )
2017-01-03 05:31:47 +00:00
}
}
}
2019-09-30 08:32:57 +00:00
2017-01-03 05:31:47 +00:00
return nil
}