diff --git a/engine.go b/engine.go index 77a636d4..51b64502 100644 --- a/engine.go +++ b/engine.go @@ -561,7 +561,6 @@ func (engine *Engine) newTable() *core.Table { } func (engine *Engine) mapType(v reflect.Value) *core.Table { - fmt.Println("has", v.NumField(), "fields") t := v.Type() table := engine.newTable() method := v.MethodByName("TableName") @@ -593,7 +592,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { var col *core.Column fieldValue := v.Field(i) fieldType := fieldValue.Type() - fmt.Println(table.Name, "===", t.Field(i).Name) if ormTagStr != "" { col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, @@ -767,7 +765,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { sqlType = core.SQLType{core.Text, 0, 0} } else { sqlType = core.Type2SQLType(fieldType) - fmt.Println(t.Field(i).Name, "...", sqlType) } col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, @@ -1016,6 +1013,138 @@ func (engine *Engine) Sync(beans ...interface{}) error { return nil } +func (engine *Engine) Sync2(beans ...interface{}) error { + tables, err := engine.DBMetas() + if err != nil { + return err + } + + for _, bean := range beans { + table := engine.autoMap(bean) + + var oriTable *core.Table + for _, tb := range tables { + if tb.Name == table.Name { + oriTable = tb + break + } + } + + if oriTable == nil { + err = engine.CreateTables(bean) + if err != nil { + return err + } + + err = engine.CreateUniques(bean) + if err != nil { + return err + } + + err = engine.CreateIndexes(bean) + if err != nil { + return err + } + } else { + for _, col := range table.Columns() { + var oriCol *core.Column + for _, col2 := range oriTable.Columns() { + if col.Name == col2.Name { + oriCol = col2 + break + } + } + + if oriCol != nil { + if col.SQLType.Name != oriCol.SQLType.Name { + if col.SQLType.Name == core.Text && + oriCol.SQLType.Name == core.Varchar { + // currently only support mysql + if engine.dialect.DBType() == core.MYSQL { + _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + } else { + engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + } + } else { + engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + } + } + if col.Default != oriCol.Default { + engine.LogWarn("Table %s Column %s Old default is %s, new default is %s", + table.Name, col.Name, oriCol.Default, col.Default) + } + if col.Nullable != oriCol.Nullable { + engine.LogWarn("Table %s Column %s Old nullable is %v, new nullable is %v", + table.Name, col.Name, oriCol.Nullable, col.Nullable) + } + } else { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addColumn(col.Name) + } + if err != nil { + return err + } + } + + var foundIndexNames = make(map[string]bool) + + for name, index := range table.Indexes { + var oriIndex *core.Index + for name2, index2 := range oriTable.Indexes { + if index.Equal(index2) { + oriIndex = index2 + foundIndexNames[name2] = true + break + } + } + + if oriIndex != nil { + if oriIndex.Type != index.Type { + sql := engine.dialect.DropIndexSql(table.Name, oriIndex) + _, err = engine.Exec(sql) + if err != nil { + return err + } + oriIndex = nil + } + } + + if oriIndex == nil { + if index.Type == core.UniqueType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addUnique(table.Name, name) + } else if index.Type == core.IndexType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addIndex(table.Name, name) + } + if err != nil { + return err + } + } + } + + for name2, index2 := range oriTable.Indexes { + if _, ok := foundIndexNames[name2]; !ok { + sql := engine.dialect.DropIndexSql(table.Name, index2) + _, err = engine.Exec(sql) + if err != nil { + return err + } + } + } + } + } + return nil +} + func (engine *Engine) unMap(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() diff --git a/postgres_dialect.go b/postgres_dialect.go index 61e75881..a088664c 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -108,10 +108,28 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { " AND column_name = ?", args }*/ +func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) +} + +func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { + quote := db.Quote + //var unique string + var idxName string = index.Name + if !strings.HasPrefix(idxName, "UQE_") && + !strings.HasPrefix(idxName, "IDX_") { + if index.Type == core.UniqueType { + idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) + } else { + idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) + } + } + return fmt.Sprintf("DROP INDEX %v", quote(idxName)) +} + func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, error) { args := []interface{}{tableName, col.Name} - - //query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" rows, err := db.DB().Query(query, args...) @@ -120,10 +138,7 @@ func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, err } defer rows.Close() - if rows.Next() { - return true, nil - } - return false, nil + return rows.Next(), nil } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { @@ -169,11 +184,7 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col } } - if isNullable == "YES" { - col.Nullable = true - } else { - col.Nullable = false - } + col.Nullable = (isNullable == "YES") switch dataType { case "character varying", "character": @@ -257,7 +268,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) return nil, err } indexName = strings.Trim(indexName, `" `) - + if strings.HasSuffix(indexName, "_pkey") { + continue + } if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { indexType = core.UniqueType } else { @@ -266,9 +279,6 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) cs := strings.Split(indexdef, "(") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") - if strings.HasSuffix(indexName, "_pkey") { - continue - } if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { newIdxName := indexName[5+len(tableName) : len(indexName)] if newIdxName != "" { diff --git a/statement.go b/statement.go index 865e67d3..33aa20a7 100644 --- a/statement.go +++ b/statement.go @@ -264,7 +264,6 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, colNames := make([]string, 0) var args = make([]interface{}, 0) - fmt.Println(table.ColumnsSeq()) for _, col := range table.Columns() { if !includeVersion && col.IsVersion { continue @@ -282,7 +281,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } - fmt.Println("===", col.Name) + fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -326,11 +325,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, goto APPEND } - fmt.Println(col.Name, "is", fieldValue) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { - fmt.Println(col.Name, "is nil") args = append(args, nil) colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) }