From 6384ada2bba6a5fa4245dd93a03f9de5c88df51b Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Sat, 31 May 2014 12:19:46 +0800 Subject: [PATCH 1/2] bug fixed --- engine.go | 10 +++++++--- session.go | 1 + statement.go | 5 +++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/engine.go b/engine.go index d2db0cc7..77a636d4 100644 --- a/engine.go +++ b/engine.go @@ -134,8 +134,8 @@ func (engine *Engine) NoCascade() *Session { // Set a table use a special cacher func (engine *Engine) MapCacher(bean interface{}, cacher core.Cacher) { v := rValue(bean) - engine.autoMapType(v) - engine.Tables[v.Type()].Cacher = cacher + tb := engine.autoMapType(v) + tb.Cacher = cacher } // NewDB provides an interface to operate database directly @@ -483,7 +483,7 @@ func (engine *Engine) Desc(colNames ...string) *Session { return session.Desc(colNames...) } -// Method Asc will generate "ORDER BY column1 DESC, column2 Asc" +// Method Asc will generate "ORDER BY column1,column2 Asc" // This method can chainable use. // // engine.Desc("name").Asc("age").Find(&users) @@ -561,6 +561,7 @@ func (engine *Engine) newTable() *core.Table { } func (engine *Engine) mapType(v reflect.Value) *core.Table { + fmt.Println("has", v.NumField(), "fields") t := v.Type() table := engine.newTable() method := v.MethodByName("TableName") @@ -587,10 +588,12 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { for i := 0; i < t.NumField(); i++ { tag := t.Field(i).Tag + ormTagStr := tag.Get(engine.TagIdentifier) var col *core.Column fieldValue := v.Field(i) fieldType := fieldValue.Type() + fmt.Println(table.Name, "===", t.Field(i).Name) if ormTagStr != "" { col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, @@ -764,6 +767,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { sqlType = core.SQLType{core.Text, 0, 0} } else { sqlType = core.Type2SQLType(fieldType) + fmt.Println(t.Field(i).Name, "...", sqlType) } col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, diff --git a/session.go b/session.go index 591850af..cfa5696e 100644 --- a/session.go +++ b/session.go @@ -3333,6 +3333,7 @@ func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, args = append(args, session.Engine.NowTime(col.SQLType.Name)) } else if col.IsVersion && session.Statement.checkVersion { args = append(args, 1) + //} else if !col.DefaultIsEmpty { } else { arg, err := session.value2Interface(col, fieldValue) if err != nil { diff --git a/statement.go b/statement.go index c19247a2..865e67d3 100644 --- a/statement.go +++ b/statement.go @@ -264,6 +264,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, colNames := make([]string, 0) var args = make([]interface{}, 0) + fmt.Println(table.ColumnsSeq()) for _, col := range table.Columns() { if !includeVersion && col.IsVersion { continue @@ -281,6 +282,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } + fmt.Println("===", col.Name) fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -291,6 +293,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, fieldType := reflect.TypeOf(fieldValue.Interface()) requiredField := useAllCols + includeNil := useAllCols if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { if b { requiredField = true @@ -323,9 +326,11 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, goto APPEND } + fmt.Println(col.Name, "is", fieldValue) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { + fmt.Println(col.Name, "is nil") args = append(args, nil) colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) } From 6d1a0ac0b0a8b354671be541d930a003c085d7b1 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Wed, 11 Jun 2014 14:01:14 +0800 Subject: [PATCH 2/2] add Sync2 for a new sync database struct methods --- engine.go | 135 +++++++++++++++++++++++++++++++++++++++++++- postgres_dialect.go | 40 ++++++++----- statement.go | 5 +- 3 files changed, 158 insertions(+), 22 deletions(-) diff --git a/engine.go b/engine.go index 77a636d4..51b64502 100644 --- a/engine.go +++ b/engine.go @@ -561,7 +561,6 @@ func (engine *Engine) newTable() *core.Table { } func (engine *Engine) mapType(v reflect.Value) *core.Table { - fmt.Println("has", v.NumField(), "fields") t := v.Type() table := engine.newTable() method := v.MethodByName("TableName") @@ -593,7 +592,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { var col *core.Column fieldValue := v.Field(i) fieldType := fieldValue.Type() - fmt.Println(table.Name, "===", t.Field(i).Name) if ormTagStr != "" { col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, @@ -767,7 +765,6 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { sqlType = core.SQLType{core.Text, 0, 0} } else { sqlType = core.Type2SQLType(fieldType) - fmt.Println(t.Field(i).Name, "...", sqlType) } col = core.NewColumn(engine.ColumnMapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, sqlType.DefaultLength, @@ -1016,6 +1013,138 @@ func (engine *Engine) Sync(beans ...interface{}) error { return nil } +func (engine *Engine) Sync2(beans ...interface{}) error { + tables, err := engine.DBMetas() + if err != nil { + return err + } + + for _, bean := range beans { + table := engine.autoMap(bean) + + var oriTable *core.Table + for _, tb := range tables { + if tb.Name == table.Name { + oriTable = tb + break + } + } + + if oriTable == nil { + err = engine.CreateTables(bean) + if err != nil { + return err + } + + err = engine.CreateUniques(bean) + if err != nil { + return err + } + + err = engine.CreateIndexes(bean) + if err != nil { + return err + } + } else { + for _, col := range table.Columns() { + var oriCol *core.Column + for _, col2 := range oriTable.Columns() { + if col.Name == col2.Name { + oriCol = col2 + break + } + } + + if oriCol != nil { + if col.SQLType.Name != oriCol.SQLType.Name { + if col.SQLType.Name == core.Text && + oriCol.SQLType.Name == core.Varchar { + // currently only support mysql + if engine.dialect.DBType() == core.MYSQL { + _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + } else { + engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + } + } else { + engine.LogWarn("Table %s Column %s Old data type is %s, new data type is %s", + table.Name, col.Name, oriCol.SQLType.Name, col.SQLType.Name) + } + } + if col.Default != oriCol.Default { + engine.LogWarn("Table %s Column %s Old default is %s, new default is %s", + table.Name, col.Name, oriCol.Default, col.Default) + } + if col.Nullable != oriCol.Nullable { + engine.LogWarn("Table %s Column %s Old nullable is %v, new nullable is %v", + table.Name, col.Name, oriCol.Nullable, col.Nullable) + } + } else { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addColumn(col.Name) + } + if err != nil { + return err + } + } + + var foundIndexNames = make(map[string]bool) + + for name, index := range table.Indexes { + var oriIndex *core.Index + for name2, index2 := range oriTable.Indexes { + if index.Equal(index2) { + oriIndex = index2 + foundIndexNames[name2] = true + break + } + } + + if oriIndex != nil { + if oriIndex.Type != index.Type { + sql := engine.dialect.DropIndexSql(table.Name, oriIndex) + _, err = engine.Exec(sql) + if err != nil { + return err + } + oriIndex = nil + } + } + + if oriIndex == nil { + if index.Type == core.UniqueType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addUnique(table.Name, name) + } else if index.Type == core.IndexType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addIndex(table.Name, name) + } + if err != nil { + return err + } + } + } + + for name2, index2 := range oriTable.Indexes { + if _, ok := foundIndexNames[name2]; !ok { + sql := engine.dialect.DropIndexSql(table.Name, index2) + _, err = engine.Exec(sql) + if err != nil { + return err + } + } + } + } + } + return nil +} + func (engine *Engine) unMap(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() diff --git a/postgres_dialect.go b/postgres_dialect.go index 61e75881..a088664c 100644 --- a/postgres_dialect.go +++ b/postgres_dialect.go @@ -108,10 +108,28 @@ func (db *postgres) TableCheckSql(tableName string) (string, []interface{}) { " AND column_name = ?", args }*/ +func (db *postgres) ModifyColumnSql(tableName string, col *core.Column) string { + return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", + tableName, col.Name, db.SqlType(col)) +} + +func (db *postgres) DropIndexSql(tableName string, index *core.Index) string { + quote := db.Quote + //var unique string + var idxName string = index.Name + if !strings.HasPrefix(idxName, "UQE_") && + !strings.HasPrefix(idxName, "IDX_") { + if index.Type == core.UniqueType { + idxName = fmt.Sprintf("UQE_%v_%v", tableName, index.Name) + } else { + idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) + } + } + return fmt.Sprintf("DROP INDEX %v", quote(idxName)) +} + func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, error) { args := []interface{}{tableName, col.Name} - - //query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + " AND column_name = $2" rows, err := db.DB().Query(query, args...) @@ -120,10 +138,7 @@ func (db *postgres) IsColumnExist(tableName string, col *core.Column) (bool, err } defer rows.Close() - if rows.Next() { - return true, nil - } - return false, nil + return rows.Next(), nil } func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) { @@ -169,11 +184,7 @@ func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Col } } - if isNullable == "YES" { - col.Nullable = true - } else { - col.Nullable = false - } + col.Nullable = (isNullable == "YES") switch dataType { case "character varying", "character": @@ -257,7 +268,9 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) return nil, err } indexName = strings.Trim(indexName, `" `) - + if strings.HasSuffix(indexName, "_pkey") { + continue + } if strings.HasPrefix(indexdef, "CREATE UNIQUE INDEX") { indexType = core.UniqueType } else { @@ -266,9 +279,6 @@ func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) cs := strings.Split(indexdef, "(") colNames = strings.Split(cs[1][0:len(cs[1])-1], ",") - if strings.HasSuffix(indexName, "_pkey") { - continue - } if strings.HasPrefix(indexName, "IDX_"+tableName) || strings.HasPrefix(indexName, "UQE_"+tableName) { newIdxName := indexName[5+len(tableName) : len(indexName)] if newIdxName != "" { diff --git a/statement.go b/statement.go index 865e67d3..33aa20a7 100644 --- a/statement.go +++ b/statement.go @@ -264,7 +264,6 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, colNames := make([]string, 0) var args = make([]interface{}, 0) - fmt.Println(table.ColumnsSeq()) for _, col := range table.Columns() { if !includeVersion && col.IsVersion { continue @@ -282,7 +281,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } - fmt.Println("===", col.Name) + fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -326,11 +325,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, goto APPEND } - fmt.Println(col.Name, "is", fieldValue) if fieldType.Kind() == reflect.Ptr { if fieldValue.IsNil() { if includeNil { - fmt.Println(col.Name, "is nil") args = append(args, nil) colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) }