From f2b20d510ce8c0ae6d029902c9e267a0cd910fd2 Mon Sep 17 00:00:00 2001 From: Nash Tsai Date: Mon, 16 Dec 2013 23:13:23 +0800 Subject: [PATCH] implemented composite key --- base_test.go | 137 ++++++++++++++++++++++++++++++++++++++------------- engine.go | 4 +- session.go | 65 ++++++++++++++++-------- statement.go | 128 ++++++++++++++++++++++++++++++++++++++++++----- table.go | 41 ++++++++++----- 5 files changed, 294 insertions(+), 81 deletions(-) diff --git a/base_test.go b/base_test.go index a54667c8..4067185f 100644 --- a/base_test.go +++ b/base_test.go @@ -1892,21 +1892,21 @@ func (p *ProcessorsStruct) AfterDelete() { } func testProcessors(engine *Engine, t *testing.T) { - tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) - if err != nil { - t.Error(err) - panic(err) - } + // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + // if err != nil { + // t.Error(err) + // panic(err) + // } - tempEngine.ShowSQL = true - err = tempEngine.DropTables(&ProcessorsStruct{}) + engine.ShowSQL = true + err := engine.DropTables(&ProcessorsStruct{}) if err != nil { t.Error(err) panic(err) } p := &ProcessorsStruct{} - err = tempEngine.CreateTables(&ProcessorsStruct{}) + err = engine.CreateTables(&ProcessorsStruct{}) if err != nil { t.Error(err) panic(err) @@ -1928,7 +1928,7 @@ func testProcessors(engine *Engine, t *testing.T) { } } - _, err = tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) + _, err = engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) if err != nil { t.Error(err) panic(err) @@ -1948,7 +1948,7 @@ func testProcessors(engine *Engine, t *testing.T) { } p2 := &ProcessorsStruct{} - _, err = tempEngine.Id(p.Id).Get(p2) + _, err = engine.Id(p.Id).Get(p2) if err != nil { t.Error(err) panic(err) @@ -1987,7 +1987,7 @@ func testProcessors(engine *Engine, t *testing.T) { p = p2 // reset - _, err = tempEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) + _, err = engine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) if err != nil { t.Error(err) panic(err) @@ -2007,7 +2007,7 @@ func testProcessors(engine *Engine, t *testing.T) { } p2 = &ProcessorsStruct{} - _, err = tempEngine.Id(p.Id).Get(p2) + _, err = engine.Id(p.Id).Get(p2) if err != nil { t.Error(err) panic(err) @@ -2045,7 +2045,7 @@ func testProcessors(engine *Engine, t *testing.T) { } p = p2 // reset - _, err = tempEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) + _, err = engine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) if err != nil { t.Error(err) panic(err) @@ -2069,7 +2069,7 @@ func testProcessors(engine *Engine, t *testing.T) { pslice := make([]*ProcessorsStruct, 0) pslice = append(pslice, &ProcessorsStruct{}) pslice = append(pslice, &ProcessorsStruct{}) - cnt, err := tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) + cnt, err := engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) if err != nil { t.Error(err) panic(err) @@ -2095,7 +2095,7 @@ func testProcessors(engine *Engine, t *testing.T) { for _, elem := range pslice { p = &ProcessorsStruct{} - _, err = tempEngine.Id(elem.Id).Get(p) + _, err = engine.Id(elem.Id).Get(p) if err != nil { t.Error(err) panic(err) @@ -2118,27 +2118,27 @@ func testProcessors(engine *Engine, t *testing.T) { } func testProcessorsTx(engine *Engine, t *testing.T) { - tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) + // if err != nil { + // t.Error(err) + // panic(err) + // } + + // tempEngine.ShowSQL = true + err := engine.DropTables(&ProcessorsStruct{}) if err != nil { t.Error(err) panic(err) } - tempEngine.ShowSQL = true - err = tempEngine.DropTables(&ProcessorsStruct{}) - if err != nil { - t.Error(err) - panic(err) - } - - err = tempEngine.CreateTables(&ProcessorsStruct{}) + err = engine.CreateTables(&ProcessorsStruct{}) if err != nil { t.Error(err) panic(err) } // test insert processors with tx rollback - session := tempEngine.NewSession() + session := engine.NewSession() err = session.Begin() if err != nil { t.Error(err) @@ -2200,7 +2200,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { } session.Close() p2 := &ProcessorsStruct{} - _, err = tempEngine.Id(p.Id).Get(p2) + _, err = engine.Id(p.Id).Get(p2) if err != nil { t.Error(err) panic(err) @@ -2214,7 +2214,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- // test insert processors with tx commit - session = tempEngine.NewSession() + session = engine.NewSession() err = session.Begin() if err != nil { t.Error(err) @@ -2261,7 +2261,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { } session.Close() p2 = &ProcessorsStruct{} - _, err = tempEngine.Id(p.Id).Get(p2) + _, err = engine.Id(p.Id).Get(p2) if err != nil { t.Error(err) panic(err) @@ -2283,7 +2283,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- // test update processors with tx rollback - session = tempEngine.NewSession() + session = engine.NewSession() err = session.Begin() if err != nil { t.Error(err) @@ -2347,7 +2347,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = tempEngine.Id(insertedId).Get(p2) + _, err = engine.Id(insertedId).Get(p2) if err != nil { t.Error(err) panic(err) @@ -2368,7 +2368,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- // test update processors with tx commit - session = tempEngine.NewSession() + session = engine.NewSession() err = session.Begin() if err != nil { t.Error(err) @@ -2415,7 +2415,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { } session.Close() p2 = &ProcessorsStruct{} - _, err = tempEngine.Id(insertedId).Get(p2) + _, err = engine.Id(insertedId).Get(p2) if err != nil { t.Error(err) panic(err) @@ -2436,7 +2436,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- // test delete processors with tx rollback - session = tempEngine.NewSession() + session = engine.NewSession() err = session.Begin() if err != nil { t.Error(err) @@ -2500,7 +2500,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { session.Close() p2 = &ProcessorsStruct{} - _, err = tempEngine.Id(insertedId).Get(p2) + _, err = engine.Id(insertedId).Get(p2) if err != nil { t.Error(err) panic(err) @@ -2521,7 +2521,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) { // -- // test delete processors with tx commit - session = tempEngine.NewSession() + session = engine.NewSession() err = session.Begin() if err != nil { t.Error(err) @@ -3253,6 +3253,71 @@ func testNullValue(engine *Engine, t *testing.T) { } +type CompositeKey struct { + Id1 int64 `xorm:"id1 pk"` + Id2 int64 `xorm:"id2 pk"` + UpdateStr string +} + +func testCompositeKey(engine *Engine, t *testing.T) { + + err := engine.DropTables(&CompositeKey{}) + if err != nil { + t.Error(err) + panic(err) + } + + err = engine.CreateTables(&CompositeKey{}) + if err != nil { + t.Error(err) + panic(err) + } + + cnt, err := engine.Insert(&CompositeKey{11, 22, ""}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("failed to insert CompositeKey{11, 22}")) + } + + cnt, err = engine.Insert(&CompositeKey{11, 22, ""}) + if err == nil || cnt == 1 { + t.Error(errors.New("inserted CompositeKey{11, 22}")) + } + + var compositeKeyVal CompositeKey + has, err := engine.Id(PK{11, 22}).Get(&compositeKeyVal) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get CompositeKey{11, 22}")) + } + + // test passing PK ptr, this test seem failed withCache + has, err = engine.Id(&PK{11, 22}).Get(&compositeKeyVal) + if err != nil { + t.Error(err) + } else if !has { + t.Error(errors.New("can't get CompositeKey{11, 22}")) + } + + compositeKeyVal = CompositeKey{UpdateStr:"test1"} + cnt, err = engine.Id(PK{11, 22}).Update(&compositeKeyVal) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't update CompositeKey{11, 22}")) + } + + cnt, err = engine.Id(PK{11, 22}).Delete(&CompositeKey{}) + if err != nil { + t.Error(err) + } else if cnt != 1 { + t.Error(errors.New("can't delete CompositeKey{11, 22}")) + } +} + + func testAll(engine *Engine, t *testing.T) { fmt.Println("-------------- directCreateTable --------------") directCreateTable(engine, t) @@ -3363,4 +3428,6 @@ func testAll3(engine *Engine, t *testing.T) { testPointerData(engine, t) fmt.Println("-------------- insert null data --------------") testNullValue(engine, t) + fmt.Println("-------------- testCompositeKey --------------") + testCompositeKey(engine, t) } diff --git a/engine.go b/engine.go index 8d5887c2..33ddaf66 100644 --- a/engine.go +++ b/engine.go @@ -40,6 +40,8 @@ type dialect interface { GetIndexes(tableName string) (map[string]*Index, error) } +type PK []interface{} + // Engine is the major struct of xorm, it means a database manager. // Commonly, an application only need one engine type Engine struct { @@ -269,7 +271,7 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Session { } // Id mehtod provoide a condition as (id) = ? -func (engine *Engine) Id(id int64) *Session { +func (engine *Engine) Id(id interface{}) *Session { session := engine.NewSession() session.IsAutoClose = true return session.Id(id) diff --git a/session.go b/session.go index 6aaa234b..01853736 100644 --- a/session.go +++ b/session.go @@ -90,7 +90,7 @@ func (session *Session) Or(querystring string, args ...interface{}) *Session { } // Method Id provides converting id as a query condition -func (session *Session) Id(id int64) *Session { +func (session *Session) Id(id interface{}) *Session { session.Statement.Id(id) return session } @@ -490,7 +490,8 @@ func (session *Session) CreateUniques(bean interface{}) error { } func (session *Session) createOneTable() error { - sql := session.Statement.genCreateSQL() + sql := session.Statement.genCreateTableSQL() + session.Engine.LogDebug("create table sql: [", sql, "]") _, err := session.exec(sql) return err } @@ -1580,6 +1581,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data var v interface{} key := col.Name fieldType := fieldValue.Type() + + //fmt.Println("column name:", key, ", fieldType:", fieldType.String()) switch fieldType.Kind() { case reflect.Complex64, reflect.Complex128: @@ -1730,19 +1733,22 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } case reflect.Ptr: // !nashtsai! TODO merge duplicated codes above - typeStr := fieldType.String() - switch typeStr { - case "*string": + //typeStr := fieldType.String() + switch fieldType { + // case "*string": + case reflect.TypeOf(&c_EMPTY_STRING): x := string(data) fieldValue.Set(reflect.ValueOf(&x)) - case "*bool": + // case "*bool": + case reflect.TypeOf(&c_BOOL_DEFAULT): d := string(data) v, err := strconv.ParseBool(d) if err != nil { return errors.New("arg " + key + " as bool: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&v)) - case "*complex64": + // case "*complex64": + case reflect.TypeOf(&c_COMPLEX64_DEFAULT): var x complex64 err := json.Unmarshal(data, &x) if err != nil { @@ -1750,7 +1756,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return err } fieldValue.Set(reflect.ValueOf(&x)) - case "*complex128": + // case "*complex128": + case reflect.TypeOf(&c_COMPLEX128_DEFAULT): var x complex128 err := json.Unmarshal(data, &x) if err != nil { @@ -1758,13 +1765,15 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return err } fieldValue.Set(reflect.ValueOf(&x)) - case "*float64": + // case "*float64": + case reflect.TypeOf(&c_FLOAT64_DEFAULT): x, err := strconv.ParseFloat(string(data), 64) if err != nil { return errors.New("arg " + key + " as float64: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) - case "*float32": + // case "*float32": + case reflect.TypeOf(&c_FLOAT32_DEFAULT): var x float32 x1, err := strconv.ParseFloat(string(data), 32) if err != nil { @@ -1772,7 +1781,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } x = float32(x1) fieldValue.Set(reflect.ValueOf(&x)) - case "*time.Time": + // case "*time.Time": + case reflect.TypeOf(&c_TIME_DEFAULT): sdata := strings.TrimSpace(string(data)) var x time.Time var err error @@ -1809,14 +1819,16 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data v = x fieldValue.Set(reflect.ValueOf(&x)) - case "*uint64": + // case "*uint64": + case reflect.TypeOf(&c_UINT64_DEFAULT): var x uint64 x, err := strconv.ParseUint(string(data), 10, 64) if err != nil { return errors.New("arg " + key + " as int: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) - case "*uint": + // case "*uint": + case reflect.TypeOf(&c_UINT_DEFAULT): var x uint x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -1824,7 +1836,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } x = uint(x1) fieldValue.Set(reflect.ValueOf(&x)) - case "*uint32": + // case "*uint32": + case reflect.TypeOf(&c_UINT32_DEFAULT): var x uint32 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -1832,7 +1845,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } x = uint32(x1) fieldValue.Set(reflect.ValueOf(&x)) - case "*uint8": + // case "*uint8": + case reflect.TypeOf(&c_UINT8_DEFAULT): var x uint8 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -1840,7 +1854,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } x = uint8(x1) fieldValue.Set(reflect.ValueOf(&x)) - case "*uint16": + // case "*uint16": + case reflect.TypeOf(&c_UINT16_DEFAULT): var x uint16 x1, err := strconv.ParseUint(string(data), 10, 64) if err != nil { @@ -1848,7 +1863,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data } x = uint16(x1) fieldValue.Set(reflect.ValueOf(&x)) - case "*int64": + // case "*int64": + case reflect.TypeOf(&c_INT64_DEFAULT): sdata := string(data) var x int64 var err error @@ -1872,7 +1888,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) - case "*int": + // case "*int": + case reflect.TypeOf(&c_INT_DEFAULT): sdata := string(data) var x int var x1 int64 @@ -1900,7 +1917,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) - case "*int32": + // case "*int32": + case reflect.TypeOf(&c_INT32_DEFAULT): sdata := string(data) var x int32 var x1 int64 @@ -1928,7 +1946,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) - case "*int8": + // case "*int8": + case reflect.TypeOf(&c_INT8_DEFAULT): sdata := string(data) var x int8 var x1 int64 @@ -1956,7 +1975,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data return errors.New("arg " + key + " as int: " + err.Error()) } fieldValue.Set(reflect.ValueOf(&x)) - case "*int16": + // case "*int16": + case reflect.TypeOf(&c_INT16_DEFAULT): sdata := string(data) var x int16 var x1 int64 @@ -2494,6 +2514,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } var condition = "" + session.Statement.processIdParam() st := session.Statement defer session.Statement.Init() if st.WhereStr != "" { @@ -2680,6 +2701,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) { false, session.Statement.allUseBool, session.Statement.boolColumnMap) var condition = "" + + session.Statement.processIdParam() if session.Statement.WhereStr != "" { condition = session.Statement.WhereStr if len(colNames) > 0 { diff --git a/statement.go b/statement.go index ed7937cd..6fdcd56a 100644 --- a/statement.go +++ b/statement.go @@ -9,6 +9,27 @@ import ( "time" ) +// !nashtsai! treat following var as interal const values +var ( + c_EMPTY_STRING = "" + c_BOOL_DEFAULT = false + c_COMPLEX64_DEFAULT = complex64(0) + c_COMPLEX128_DEFAULT = complex128(0) + c_FLOAT32_DEFAULT = float32(0) + c_FLOAT64_DEFAULT = float64(0) + c_INT64_DEFAULT = int64(0) + c_UINT64_DEFAULT = uint64(0) + c_INT32_DEFAULT = int32(0) + c_UINT32_DEFAULT = uint32(0) + c_INT16_DEFAULT = int16(0) + c_UINT16_DEFAULT = uint16(0) + c_INT8_DEFAULT = int8(0) + c_UINT8_DEFAULT = uint8(0) + c_INT_DEFAULT = int(0) + c_UINT_DEFAULT = uint(0) + c_TIME_DEFAULT time.Time = time.Unix(0, 0) +) + // statement save all the sql info for executing SQL type Statement struct { RefTable *Table @@ -16,6 +37,7 @@ type Statement struct { Start int LimitN int WhereStr string + IdParam *PK Params []interface{} OrderStr string JoinStr string @@ -256,7 +278,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, if fieldValue.IsNil() { if includeNil { args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) } continue } else if !fieldValue.IsValid() { @@ -376,7 +398,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{}, } args = append(args, val) - colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) } return colNames, args @@ -394,15 +416,40 @@ func (statement *Statement) TableName() string { return "" } -// Generate "Where id = ? " statment -func (statement *Statement) Id(id int64) *Statement { - if statement.WhereStr == "" { - statement.WhereStr = "(id)=?" - statement.Params = []interface{}{id} - } else { - statement.WhereStr = statement.WhereStr + " AND (id)=?" - statement.Params = append(statement.Params, id) +// Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" +func (statement *Statement) Id(id interface{}) *Statement { + + idValue := reflect.ValueOf(id) + idType := reflect.TypeOf(idValue.Interface()) + + switch idType { + case reflect.TypeOf(&PK{}): + if pkPtr, ok := (id).(*PK); ok { + statement.IdParam = pkPtr + } + case reflect.TypeOf(PK{}): + if pk, ok := (id).(PK); ok { + statement.IdParam = &pk + } + default: + // TODO treat as int primitve for now, need to handle type check + statement.IdParam = &PK{id} + + // !nashtsai! REVIEW although it will be user's mistake if called Id() twice with + // different value and Id should be PK's field name, however, at this stage probably + // can't tell which table is gonna be used + // if statement.WhereStr == "" { + // statement.WhereStr = "(id)=?" + // statement.Params = []interface{}{id} + // } else { + // // TODO what if id param has already passed + // statement.WhereStr = statement.WhereStr + " AND (id)=?" + // statement.Params = append(statement.Params, id) + // } } + + // !nashtsai! perhaps no need to validate pk values' type just let sql complaint happen + return statement } @@ -559,14 +606,37 @@ func (statement *Statement) genColumnStr() string { return strings.Join(colNames, ", ") } -func (statement *Statement) genCreateSQL() string { +func (statement *Statement) genCreateTableSQL() string { sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " (" + + + pkList := []string{} + for _, colName := range statement.RefTable.ColumnsSeq { col := statement.RefTable.Columns[colName] - sql += col.String(statement.Engine.dialect) + if col.IsPrimaryKey { + pkList = append(pkList, col.Name) + } + } + + statement.Engine.LogDebug("len:", len(pkList)) + for _, colName := range statement.RefTable.ColumnsSeq { + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey && len(pkList) == 1 { + sql += col.String(statement.Engine.dialect) + } else { + sql += col.stringNoPk(statement.Engine.dialect) + } sql = strings.TrimSpace(sql) sql += ", " } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += strings.Join(pkList, ",") + sql += " ), " + } + sql = sql[:len(sql)-2] + ")" if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" { sql += " ENGINE=" + statement.StoreEngine @@ -702,11 +772,14 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.IsDistinct { distinct = "DISTINCT " } + + // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr, statement.Engine.Quote(statement.TableName())) if statement.JoinStr != "" { a = fmt.Sprintf("%v %v", a, statement.JoinStr) } + statement.processIdParam() if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) if statement.ConditionStr != "" { @@ -732,3 +805,34 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } return } + +func (statement *Statement) processIdParam() { + + if statement.IdParam != nil { + i := 0 + colCnt := len(statement.RefTable.ColumnsSeq) + for _, elem := range *(statement.IdParam) { + for ; i < colCnt; i++ { + colName := statement.RefTable.ColumnsSeq[i] + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey { + statement.And(fmt.Sprintf("%v=?", col.Name), elem) + i++ + break + } + } + } + + // !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it + // as empty string for now, so this will result sql exec failed instead of unexpected + // false update/delete + for ; i < colCnt; i++ { + colName := statement.RefTable.ColumnsSeq[i] + col := statement.RefTable.Columns[colName] + if col.IsPrimaryKey { + statement.And(fmt.Sprintf("%v=?", col.Name), "") + } + } + } +} + diff --git a/table.go b/table.go index 75ed8b2f..e408ea85 100644 --- a/table.go +++ b/table.go @@ -155,25 +155,25 @@ func Type2SQLType(t reflect.Type) (st SQLType) { } func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { - typeStr := t.String() has = true - switch typeStr { - case "*string": + switch t { + case reflect.TypeOf(&c_EMPTY_STRING): st = SQLType{Varchar, 255, 0} - case "*bool": + return + case reflect.TypeOf(&c_BOOL_DEFAULT): st = SQLType{Bool, 0, 0} - case "*complex64", "*complex128": + case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT): st = SQLType{Varchar, 64, 0} - case "*float32": + case reflect.TypeOf(&c_FLOAT32_DEFAULT): st = SQLType{Float, 0, 0} - case "*float64": + case reflect.TypeOf(&c_FLOAT64_DEFAULT): st = SQLType{Double, 0, 0} - case "*int64", "*uint64": + case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT): st = SQLType{BigInt, 0, 0} - case "*time.Time": + case reflect.TypeOf(&c_TIME_DEFAULT): st = SQLType{DateTime, 0, 0} - case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8": + case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT): st = SQLType{Int, 0, 0} default: has = false @@ -265,12 +265,29 @@ func (col *Column) String(d dialect) string { if col.IsPrimaryKey { sql += "PRIMARY KEY " + if col.IsAutoIncrement { + sql += d.AutoIncrStr() + " " + } } - if col.IsAutoIncrement { - sql += d.AutoIncrStr() + " " + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " } + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + return sql +} + +func (col *Column) stringNoPk(d dialect) string { + sql := d.QuoteStr() + col.Name + d.QuoteStr() + " " + + sql += d.SqlType(col) + " " + if col.Nullable { sql += "NULL " } else {