From 24a6ab1609a70e132db69b82492ca82ebd9725a1 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Sat, 25 Mar 2017 11:10:23 -0300 Subject: [PATCH] Add migrate package for schema versioned migrations (#549) This is a port from https://github.com/go-gormigrate/gormigrate --- .gitignore | 3 +- migrate/migrate.go | 214 ++++++++++++++++++++++++++++++++++++++++ migrate/migrate_test.go | 150 ++++++++++++++++++++++++++++ 3 files changed, 366 insertions(+), 1 deletion(-) create mode 100644 migrate/migrate.go create mode 100644 migrate/migrate_test.go diff --git a/.gitignore b/.gitignore index fa31dd78..f2c9a70a 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ vendor .vendor temp_test.go .vscode -xorm.test \ No newline at end of file +xorm.test +*.sqlite3 diff --git a/migrate/migrate.go b/migrate/migrate.go new file mode 100644 index 00000000..6c2a13a8 --- /dev/null +++ b/migrate/migrate.go @@ -0,0 +1,214 @@ +package migrate + +import ( + "errors" + "fmt" + + "github.com/go-xorm/xorm" +) + +// MigrateFunc is the func signature for migrating. +type MigrateFunc func(*xorm.Engine) error + +// RollbackFunc is the func signature for rollbacking. +type RollbackFunc func(*xorm.Engine) error + +// InitSchemaFunc is the func signature for initializing the schema. +type InitSchemaFunc func(*xorm.Engine) error + +// Options define options for all migrations. +type Options struct { + // TableName is the migration table. + TableName string + // IDColumnName is the name of column where the migration id will be stored. + IDColumnName string +} + +// Migration represents a database migration (a modification to be made on the database). +type Migration struct { + // ID is the migration identifier. Usually a timestamp like "201601021504". + ID string + // Migrate is a function that will br executed while running this migration. + Migrate MigrateFunc + // Rollback will be executed on rollback. Can be nil. + Rollback RollbackFunc +} + +// Migrate represents a collection of all migrations of a database schema. +type Migrate struct { + db *xorm.Engine + options *Options + migrations []*Migration + initSchema InitSchemaFunc +} + +var ( + // DefaultOptions can be used if you don't want to think about options. + DefaultOptions = &Options{ + TableName: "migrations", + IDColumnName: "id", + } + + // ErrRollbackImpossible is returned when trying to rollback a migration + // that has no rollback function. + ErrRollbackImpossible = errors.New("It's impossible to rollback this migration") + + // ErrNoMigrationDefined is returned when no migration is defined. + ErrNoMigrationDefined = errors.New("No migration defined") + + // ErrMissingID is returned when the ID od migration is equal to "" + ErrMissingID = errors.New("Missing ID in migration") + + // ErrNoRunnedMigration is returned when any runned migration was found while + // running RollbackLast + ErrNoRunnedMigration = errors.New("Could not find last runned migration") +) + +// New returns a new Gormigrate. +func New(db *xorm.Engine, options *Options, migrations []*Migration) *Migrate { + return &Migrate{ + db: db, + options: options, + migrations: migrations, + } +} + +// InitSchema sets a function that is run if no migration is found. +// The idea is preventing to run all migrations when a new clean database +// is being migrating. In this function you should create all tables and +// foreign key necessary to your application. +func (m *Migrate) InitSchema(initSchema InitSchemaFunc) { + m.initSchema = initSchema +} + +// Migrate executes all migrations that did not run yet. +func (m *Migrate) Migrate() error { + if err := m.createMigrationTableIfNotExists(); err != nil { + return err + } + + if m.initSchema != nil && m.isFirstRun() { + if err := m.runInitSchema(); err != nil { + return err + } + return nil + } + + for _, migration := range m.migrations { + if err := m.runMigration(migration); err != nil { + return err + } + } + return nil +} + +// RollbackLast undo the last migration +func (m *Migrate) RollbackLast() error { + if len(m.migrations) == 0 { + return ErrNoMigrationDefined + } + + lastRunnedMigration, err := m.getLastRunnedMigration() + if err != nil { + return err + } + + if err := m.RollbackMigration(lastRunnedMigration); err != nil { + return err + } + return nil +} + +func (m *Migrate) getLastRunnedMigration() (*Migration, error) { + for i := len(m.migrations) - 1; i >= 0; i-- { + migration := m.migrations[i] + if m.migrationDidRun(migration) { + return migration, nil + } + } + return nil, ErrNoRunnedMigration +} + +// RollbackMigration undo a migration. +func (m *Migrate) RollbackMigration(mig *Migration) error { + if mig.Rollback == nil { + return ErrRollbackImpossible + } + + if err := mig.Rollback(m.db); err != nil { + return err + } + + sql := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName) + if _, err := m.db.Exec(sql, mig.ID); err != nil { + return err + } + return nil +} + +func (m *Migrate) runInitSchema() error { + if err := m.initSchema(m.db); err != nil { + return err + } + + for _, migration := range m.migrations { + if err := m.insertMigration(migration.ID); err != nil { + return err + } + } + + return nil +} + +func (m *Migrate) runMigration(migration *Migration) error { + if len(migration.ID) == 0 { + return ErrMissingID + } + + if !m.migrationDidRun(migration) { + if err := migration.Migrate(m.db); err != nil { + return err + } + + if err := m.insertMigration(migration.ID); err != nil { + return err + } + } + return nil +} + +func (m *Migrate) createMigrationTableIfNotExists() error { + exists, err := m.db.IsTableExist(m.options.TableName) + if err != nil { + return err + } + if exists { + return nil + } + + sql := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(255) PRIMARY KEY)", m.options.TableName, m.options.IDColumnName) + if _, err := m.db.Exec(sql); err != nil { + return err + } + return nil +} + +func (m *Migrate) migrationDidRun(mig *Migration) bool { + row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s = ?", m.options.TableName, m.options.IDColumnName), mig.ID) + var count int + row.Scan(&count) + return count > 0 +} + +func (m *Migrate) isFirstRun() bool { + row := m.db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", m.options.TableName)) + var count int + row.Scan(&count) + return count == 0 +} + +func (m *Migrate) insertMigration(id string) error { + sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?)", m.options.TableName, m.options.IDColumnName) + _, err := m.db.Exec(sql, id) + return err +} diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go new file mode 100644 index 00000000..086707ec --- /dev/null +++ b/migrate/migrate_test.go @@ -0,0 +1,150 @@ +package migrate + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/go-xorm/xorm" + _ "github.com/mattn/go-sqlite3" + "gopkg.in/stretchr/testify.v1/assert" +) + +type Person struct { + ID int64 + Name string +} + +type Pet struct { + ID int64 + Name string + PersonID int +} + +const ( + dbName = "testdb.sqlite3" +) + +var ( + migrations = []*Migration{ + { + ID: "201608301400", + Migrate: func(tx *xorm.Engine) error { + return tx.Sync2(&Person{}) + }, + Rollback: func(tx *xorm.Engine) error { + return tx.DropTables(&Person{}) + }, + }, + { + ID: "201608301430", + Migrate: func(tx *xorm.Engine) error { + return tx.Sync2(&Pet{}) + }, + Rollback: func(tx *xorm.Engine) error { + return tx.DropTables(&Pet{}) + }, + }, + } +) + +func TestMigration(t *testing.T) { + _ = os.Remove(dbName) + + db, err := xorm.NewEngine("sqlite3", dbName) + if err != nil { + log.Fatal(err) + } + defer db.Close() + + if err = db.DB().Ping(); err != nil { + log.Fatal(err) + } + + m := New(db, DefaultOptions, migrations) + + err = m.Migrate() + assert.NoError(t, err) + exists, _ := db.IsTableExist(&Person{}) + assert.True(t, exists) + exists, _ = db.IsTableExist(&Pet{}) + assert.True(t, exists) + assert.Equal(t, 2, tableCount(db, "migrations")) + + err = m.RollbackLast() + assert.NoError(t, err) + exists, _ = db.IsTableExist(&Person{}) + assert.True(t, exists) + exists, _ = db.IsTableExist(&Pet{}) + assert.False(t, exists) + assert.Equal(t, 1, tableCount(db, "migrations")) + + err = m.RollbackLast() + assert.NoError(t, err) + exists, _ = db.IsTableExist(&Person{}) + assert.False(t, exists) + exists, _ = db.IsTableExist(&Pet{}) + assert.False(t, exists) + assert.Equal(t, 0, tableCount(db, "migrations")) +} + +func TestInitSchema(t *testing.T) { + os.Remove(dbName) + + db, err := xorm.NewEngine("sqlite3", dbName) + if err != nil { + log.Fatal(err) + } + defer db.Close() + if err = db.DB().Ping(); err != nil { + log.Fatal(err) + } + + m := New(db, DefaultOptions, migrations) + m.InitSchema(func(tx *xorm.Engine) error { + if err := tx.Sync2(&Person{}); err != nil { + return err + } + if err := tx.Sync2(&Pet{}); err != nil { + return err + } + return nil + }) + + err = m.Migrate() + assert.NoError(t, err) + exists, _ := db.IsTableExist(&Person{}) + assert.True(t, exists) + exists, _ = db.IsTableExist(&Pet{}) + assert.True(t, exists) + assert.Equal(t, 2, tableCount(db, "migrations")) +} + +func TestMissingID(t *testing.T) { + os.Remove(dbName) + + db, err := xorm.NewEngine("sqlite3", dbName) + assert.NoError(t, err) + if db != nil { + defer db.Close() + } + assert.NoError(t, db.DB().Ping()) + + migrationsMissingID := []*Migration{ + { + Migrate: func(tx *xorm.Engine) error { + return nil + }, + }, + } + + m := New(db, DefaultOptions, migrationsMissingID) + assert.Equal(t, ErrMissingID, m.Migrate()) +} + +func tableCount(db *xorm.Engine, tableName string) (count int) { + row := db.DB().QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)) + row.Scan(&count) + return +}