diff --git a/engine.go b/engine.go index 46180438..2bd8d523 100644 --- a/engine.go +++ b/engine.go @@ -134,8 +134,8 @@ func (engine *Engine) NoCascade() *Session { // Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) { v := rValue(bean) - engine.autoMapType(v) - engine.Tables[v.Type()].Cacher = cacher + tb := engine.autoMapType(v) + tb.Cacher = cacher } // NewDB provides an interface to operate database directly @@ -483,7 +483,7 @@ func (engine *Engine) Desc(colNames ...string) *Session { return session.Desc(colNames...) } -// Method Asc will generate "ORDER BY column1 DESC, column2 Asc" +// Method Asc will generate "ORDER BY column1,column2 Asc" // This method can chainable use. // // engine.Desc("name").Asc("age").Find(&users) @@ -587,6 +587,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { for i := 0; i < t.NumField(); i++ { tag := t.Field(i).Tag + ormTagStr := tag.Get(engine.TagIdentifier) var col *core.Column fieldValue := v.Field(i) @@ -1012,6 +1013,138 @@ func (engine *Engine) Sync(beans ...interface{}) error { return nil } +func (engine *Engine) Sync2(beans ...interface{}) error { + tables, err := engine.DBMetas() + if err != nil { + return err + } + + for _, bean := range beans { + table := engine.autoMap(bean) + + var oriTable *core.Table + for _, tb := range tables { + if tb.Name == table.Name { + oriTable = tb + break + } + } + + if oriTable == nil { + err = engine.CreateTables(bean) + if err != nil { + return err + } + + err = engine.CreateUniques(bean) + if err != nil { + return err + } + + err = engine.CreateIndexes(bean) + if err != nil { + return err + } + } else { + for _, col := range table.Columns() { + var oriCol *core.Column + for _, col2 := range oriTable.Columns() { + if col.Name == col2.Name { + oriCol = col2 + break + } + } + + if oriCol != nil { + if col.SQLType.Name != oriCol.SQLType.Name { + if col.SQLType.Name == core.Text && + oriCol.SQLType.Name == core.Varchar { + // currently only support mysql + if engine.dialect.DBType() == core.MYSQL { + _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + } else { + engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + } + } else { + engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + } + } + if col.Default != oriCol.Default { + engine.LogWarn("Table %s Column %s Old default is %s, new default is %s", + table.Name, col.Name, oriCol.Default, col.Default) + } + if col.Nullable != oriCol.Nullable { + engine.LogWarn("Table %s Column %s Old nullable is %v, new nullable is %v", + table.Name, col.Name, oriCol.Nullable, col.Nullable) + } + } else { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addColumn(col.Name) + } + if err != nil { + return err + } + } + + var foundIndexNames = make(map[string]bool) + + for name, index := range table.Indexes { + var oriIndex *core.Index + for name2, index2 := range oriTable.Indexes { + if index.Equal(index2) { + oriIndex = index2 + foundIndexNames[name2] = true + break + } + } + + if oriIndex != nil { + if oriIndex.Type != index.Type { + sql := engine.dialect.DropIndexSql(table.Name, oriIndex) + _, err = engine.Exec(sql) + if err != nil { + return err + } + oriIndex = nil + } + } + + if oriIndex == nil { + if index.Type == core.UniqueType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addUnique(table.Name, name) + } else if index.Type == core.IndexType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addIndex(table.Name, name) + } + if err != nil { + return err + } + } + } + + for name2, index2 := range oriTable.Indexes { + if _, ok := foundIndexNames[name2]; !ok { + sql := engine.dialect.DropIndexSql(table.Name, index2) + _, err = engine.Exec(sql) + if err != nil { + return err + } + } + } + } + } + return nil +} + func (engine *Engine) unMap(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() diff --git a/postgres_dialect.go b/postgres_dialect.go index 61e75881..a088664c 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -108,10 +108,28 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { " AND column_name = ?", args }*/ +func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) +} + +func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { + quote := db.Quote + //var unique string + var idxName string = index.Name + if !strings.HasPrefix(idxName, "UQE_") && + !strings.HasPrefix(idxName, "IDX_") { + if index.Type == core.UniqueType { + idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) + } else { + idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) + } + } + return fmt.Sprintf("DROP INDEX %v", quote(idxName)) +} + func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, error) { args := []interface{}{tableName, col.Name} - - //query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" rows, err := db.DB().Query(query, args...) @@ -120,10 +138,7 @@ func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, err } defer rows.Close() - if rows.Next() { - return true, nil - } - return false, nil + return rows.Next(), nil } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { @@ -169,11 +184,7 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col } } - if isNullable == "YES" { - col.Nullable = true - } else { - col.Nullable = false - } + col.Nullable = (isNullable == "YES") switch dataType { case "character varying", "character": @@ -257,7 +268,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) return nil, err } indexName = strings.Trim(indexName, `" `) - + if strings.HasSuffix(indexName, "_pkey") { + continue + } if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { indexType = core.UniqueType } else { @@ -266,9 +279,6 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) cs := strings.Split(indexdef, "(") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") - if strings.HasSuffix(indexName, "_pkey") { - continue - } if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { newIdxName := indexName[5+len(tableName) : len(indexName)] if newIdxName != "" { diff --git a/session.go b/session.go index 01319376..b1636f4e 100644 --- a/session.go +++ b/session.go @@ -3344,6 +3344,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, args = append(args, session.Engine.NowTime(col.SQLType.Name)) } else if col.IsVersion && session.Statement.checkVersion { args = append(args, 1) + //} else if !col.DefaultIsEmpty { } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { diff --git a/statement.go b/statement.go index c19247a2..33aa20a7 100644 --- a/statement.go +++ b/statement.go @@ -281,6 +281,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } + fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -291,6 +292,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, fieldType := reflect.TypeOf(fieldValue.Interface()) requiredField := useAllCols + includeNil := useAllCols if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { if b { requiredField = true