implemented composite key

This commit is contained in:
Nash Tsai 2013-12-16 23:13:23 +08:00
parent 8b84f5d692
commit f2b20d510c
5 changed files with 294 additions and 81 deletions

View File

@ -1892,21 +1892,21 @@ func (p *ProcessorsStruct) AfterDelete() {
} }
func testProcessors(engine *Engine, t *testing.T) { func testProcessors(engine *Engine, t *testing.T) {
tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
if err != nil { // if err != nil {
t.Error(err) // t.Error(err)
panic(err) // panic(err)
} // }
tempEngine.ShowSQL = true engine.ShowSQL = true
err = tempEngine.DropTables(&ProcessorsStruct{}) err := engine.DropTables(&ProcessorsStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
p := &ProcessorsStruct{} p := &ProcessorsStruct{}
err = tempEngine.CreateTables(&ProcessorsStruct{}) err = engine.CreateTables(&ProcessorsStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1928,7 +1928,7 @@ func testProcessors(engine *Engine, t *testing.T) {
} }
} }
_, err = tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p) _, err = engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1948,7 +1948,7 @@ func testProcessors(engine *Engine, t *testing.T) {
} }
p2 := &ProcessorsStruct{} p2 := &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2) _, err = engine.Id(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1987,7 +1987,7 @@ func testProcessors(engine *Engine, t *testing.T) {
p = p2 // reset p = p2 // reset
_, err = tempEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p) _, err = engine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2007,7 +2007,7 @@ func testProcessors(engine *Engine, t *testing.T) {
} }
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2) _, err = engine.Id(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2045,7 +2045,7 @@ func testProcessors(engine *Engine, t *testing.T) {
} }
p = p2 // reset p = p2 // reset
_, err = tempEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p) _, err = engine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2069,7 +2069,7 @@ func testProcessors(engine *Engine, t *testing.T) {
pslice := make([]*ProcessorsStruct, 0) pslice := make([]*ProcessorsStruct, 0)
pslice = append(pslice, &ProcessorsStruct{}) pslice = append(pslice, &ProcessorsStruct{})
pslice = append(pslice, &ProcessorsStruct{}) pslice = append(pslice, &ProcessorsStruct{})
cnt, err := tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice) cnt, err := engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2095,7 +2095,7 @@ func testProcessors(engine *Engine, t *testing.T) {
for _, elem := range pslice { for _, elem := range pslice {
p = &ProcessorsStruct{} p = &ProcessorsStruct{}
_, err = tempEngine.Id(elem.Id).Get(p) _, err = engine.Id(elem.Id).Get(p)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2118,27 +2118,27 @@ func testProcessors(engine *Engine, t *testing.T) {
} }
func testProcessorsTx(engine *Engine, t *testing.T) { func testProcessorsTx(engine *Engine, t *testing.T) {
tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName) // tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
// if err != nil {
// t.Error(err)
// panic(err)
// }
// tempEngine.ShowSQL = true
err := engine.DropTables(&ProcessorsStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
tempEngine.ShowSQL = true err = engine.CreateTables(&ProcessorsStruct{})
err = tempEngine.DropTables(&ProcessorsStruct{})
if err != nil {
t.Error(err)
panic(err)
}
err = tempEngine.CreateTables(&ProcessorsStruct{})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
// test insert processors with tx rollback // test insert processors with tx rollback
session := tempEngine.NewSession() session := engine.NewSession()
err = session.Begin() err = session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -2200,7 +2200,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
} }
session.Close() session.Close()
p2 := &ProcessorsStruct{} p2 := &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2) _, err = engine.Id(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2214,7 +2214,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// -- // --
// test insert processors with tx commit // test insert processors with tx commit
session = tempEngine.NewSession() session = engine.NewSession()
err = session.Begin() err = session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -2261,7 +2261,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
} }
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2) _, err = engine.Id(p.Id).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2283,7 +2283,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// -- // --
// test update processors with tx rollback // test update processors with tx rollback
session = tempEngine.NewSession() session = engine.NewSession()
err = session.Begin() err = session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -2347,7 +2347,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(insertedId).Get(p2) _, err = engine.Id(insertedId).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2368,7 +2368,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// -- // --
// test update processors with tx commit // test update processors with tx commit
session = tempEngine.NewSession() session = engine.NewSession()
err = session.Begin() err = session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -2415,7 +2415,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
} }
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(insertedId).Get(p2) _, err = engine.Id(insertedId).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2436,7 +2436,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// -- // --
// test delete processors with tx rollback // test delete processors with tx rollback
session = tempEngine.NewSession() session = engine.NewSession()
err = session.Begin() err = session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -2500,7 +2500,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
session.Close() session.Close()
p2 = &ProcessorsStruct{} p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(insertedId).Get(p2) _, err = engine.Id(insertedId).Get(p2)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -2521,7 +2521,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// -- // --
// test delete processors with tx commit // test delete processors with tx commit
session = tempEngine.NewSession() session = engine.NewSession()
err = session.Begin() err = session.Begin()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@ -3253,6 +3253,71 @@ func testNullValue(engine *Engine, t *testing.T) {
} }
type CompositeKey struct {
Id1 int64 `xorm:"id1 pk"`
Id2 int64 `xorm:"id2 pk"`
UpdateStr string
}
func testCompositeKey(engine *Engine, t *testing.T) {
err := engine.DropTables(&CompositeKey{})
if err != nil {
t.Error(err)
panic(err)
}
err = engine.CreateTables(&CompositeKey{})
if err != nil {
t.Error(err)
panic(err)
}
cnt, err := engine.Insert(&CompositeKey{11, 22, ""})
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("failed to insert CompositeKey{11, 22}"))
}
cnt, err = engine.Insert(&CompositeKey{11, 22, ""})
if err == nil || cnt == 1 {
t.Error(errors.New("inserted CompositeKey{11, 22}"))
}
var compositeKeyVal CompositeKey
has, err := engine.Id(PK{11, 22}).Get(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if !has {
t.Error(errors.New("can't get CompositeKey{11, 22}"))
}
// test passing PK ptr, this test seem failed withCache
has, err = engine.Id(&PK{11, 22}).Get(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if !has {
t.Error(errors.New("can't get CompositeKey{11, 22}"))
}
compositeKeyVal = CompositeKey{UpdateStr:"test1"}
cnt, err = engine.Id(PK{11, 22}).Update(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}"))
}
cnt, err = engine.Id(PK{11, 22}).Delete(&CompositeKey{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't delete CompositeKey{11, 22}"))
}
}
func testAll(engine *Engine, t *testing.T) { func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------") fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t) directCreateTable(engine, t)
@ -3363,4 +3428,6 @@ func testAll3(engine *Engine, t *testing.T) {
testPointerData(engine, t) testPointerData(engine, t)
fmt.Println("-------------- insert null data --------------") fmt.Println("-------------- insert null data --------------")
testNullValue(engine, t) testNullValue(engine, t)
fmt.Println("-------------- testCompositeKey --------------")
testCompositeKey(engine, t)
} }

View File

@ -40,6 +40,8 @@ type dialect interface {
GetIndexes(tableName string) (map[string]*Index, error) GetIndexes(tableName string) (map[string]*Index, error)
} }
type PK []interface{}
// Engine is the major struct of xorm, it means a database manager. // Engine is the major struct of xorm, it means a database manager.
// Commonly, an application only need one engine // Commonly, an application only need one engine
type Engine struct { type Engine struct {
@ -269,7 +271,7 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
} }
// Id mehtod provoide a condition as (id) = ? // Id mehtod provoide a condition as (id) = ?
func (engine *Engine) Id(id int64) *Session { func (engine *Engine) Id(id interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
return session.Id(id) return session.Id(id)

View File

@ -90,7 +90,7 @@ func (session *Session) Or(querystring string, args ...interface{}) *Session {
} }
// Method Id provides converting id as a query condition // Method Id provides converting id as a query condition
func (session *Session) Id(id int64) *Session { func (session *Session) Id(id interface{}) *Session {
session.Statement.Id(id) session.Statement.Id(id)
return session return session
} }
@ -490,7 +490,8 @@ func (session *Session) CreateUniques(bean interface{}) error {
} }
func (session *Session) createOneTable() error { func (session *Session) createOneTable() error {
sql := session.Statement.genCreateSQL() sql := session.Statement.genCreateTableSQL()
session.Engine.LogDebug("create table sql: [", sql, "]")
_, err := session.exec(sql) _, err := session.exec(sql)
return err return err
} }
@ -1580,6 +1581,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var v interface{} var v interface{}
key := col.Name key := col.Name
fieldType := fieldValue.Type() fieldType := fieldValue.Type()
//fmt.Println("column name:", key, ", fieldType:", fieldType.String()) //fmt.Println("column name:", key, ", fieldType:", fieldType.String())
switch fieldType.Kind() { switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
@ -1730,19 +1733,22 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
case reflect.Ptr: case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above // !nashtsai! TODO merge duplicated codes above
typeStr := fieldType.String() //typeStr := fieldType.String()
switch typeStr { switch fieldType {
case "*string": // case "*string":
case reflect.TypeOf(&c_EMPTY_STRING):
x := string(data) x := string(data)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*bool": // case "*bool":
case reflect.TypeOf(&c_BOOL_DEFAULT):
d := string(data) d := string(data)
v, err := strconv.ParseBool(d) v, err := strconv.ParseBool(d)
if err != nil { if err != nil {
return errors.New("arg " + key + " as bool: " + err.Error()) return errors.New("arg " + key + " as bool: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&v)) fieldValue.Set(reflect.ValueOf(&v))
case "*complex64": // case "*complex64":
case reflect.TypeOf(&c_COMPLEX64_DEFAULT):
var x complex64 var x complex64
err := json.Unmarshal(data, &x) err := json.Unmarshal(data, &x)
if err != nil { if err != nil {
@ -1750,7 +1756,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return err return err
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*complex128": // case "*complex128":
case reflect.TypeOf(&c_COMPLEX128_DEFAULT):
var x complex128 var x complex128
err := json.Unmarshal(data, &x) err := json.Unmarshal(data, &x)
if err != nil { if err != nil {
@ -1758,13 +1765,15 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return err return err
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*float64": // case "*float64":
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
x, err := strconv.ParseFloat(string(data), 64) x, err := strconv.ParseFloat(string(data), 64)
if err != nil { if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error()) return errors.New("arg " + key + " as float64: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*float32": // case "*float32":
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
var x float32 var x float32
x1, err := strconv.ParseFloat(string(data), 32) x1, err := strconv.ParseFloat(string(data), 32)
if err != nil { if err != nil {
@ -1772,7 +1781,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
x = float32(x1) x = float32(x1)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*time.Time": // case "*time.Time":
case reflect.TypeOf(&c_TIME_DEFAULT):
sdata := strings.TrimSpace(string(data)) sdata := strings.TrimSpace(string(data))
var x time.Time var x time.Time
var err error var err error
@ -1809,14 +1819,16 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
v = x v = x
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*uint64": // case "*uint64":
case reflect.TypeOf(&c_UINT64_DEFAULT):
var x uint64 var x uint64
x, err := strconv.ParseUint(string(data), 10, 64) x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
return errors.New("arg " + key + " as int: " + err.Error()) return errors.New("arg " + key + " as int: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*uint": // case "*uint":
case reflect.TypeOf(&c_UINT_DEFAULT):
var x uint var x uint
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -1824,7 +1836,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
x = uint(x1) x = uint(x1)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*uint32": // case "*uint32":
case reflect.TypeOf(&c_UINT32_DEFAULT):
var x uint32 var x uint32
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -1832,7 +1845,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
x = uint32(x1) x = uint32(x1)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*uint8": // case "*uint8":
case reflect.TypeOf(&c_UINT8_DEFAULT):
var x uint8 var x uint8
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -1840,7 +1854,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
x = uint8(x1) x = uint8(x1)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*uint16": // case "*uint16":
case reflect.TypeOf(&c_UINT16_DEFAULT):
var x uint16 var x uint16
x1, err := strconv.ParseUint(string(data), 10, 64) x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil { if err != nil {
@ -1848,7 +1863,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
} }
x = uint16(x1) x = uint16(x1)
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*int64": // case "*int64":
case reflect.TypeOf(&c_INT64_DEFAULT):
sdata := string(data) sdata := string(data)
var x int64 var x int64
var err error var err error
@ -1872,7 +1888,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error()) return errors.New("arg " + key + " as int: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*int": // case "*int":
case reflect.TypeOf(&c_INT_DEFAULT):
sdata := string(data) sdata := string(data)
var x int var x int
var x1 int64 var x1 int64
@ -1900,7 +1917,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error()) return errors.New("arg " + key + " as int: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*int32": // case "*int32":
case reflect.TypeOf(&c_INT32_DEFAULT):
sdata := string(data) sdata := string(data)
var x int32 var x int32
var x1 int64 var x1 int64
@ -1928,7 +1946,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error()) return errors.New("arg " + key + " as int: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*int8": // case "*int8":
case reflect.TypeOf(&c_INT8_DEFAULT):
sdata := string(data) sdata := string(data)
var x int8 var x int8
var x1 int64 var x1 int64
@ -1956,7 +1975,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error()) return errors.New("arg " + key + " as int: " + err.Error())
} }
fieldValue.Set(reflect.ValueOf(&x)) fieldValue.Set(reflect.ValueOf(&x))
case "*int16": // case "*int16":
case reflect.TypeOf(&c_INT16_DEFAULT):
sdata := string(data) sdata := string(data)
var x int16 var x int16
var x1 int64 var x1 int64
@ -2494,6 +2514,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
} }
var condition = "" var condition = ""
session.Statement.processIdParam()
st := session.Statement st := session.Statement
defer session.Statement.Init() defer session.Statement.Init()
if st.WhereStr != "" { if st.WhereStr != "" {
@ -2680,6 +2701,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
false, session.Statement.allUseBool, session.Statement.boolColumnMap) false, session.Statement.allUseBool, session.Statement.boolColumnMap)
var condition = "" var condition = ""
session.Statement.processIdParam()
if session.Statement.WhereStr != "" { if session.Statement.WhereStr != "" {
condition = session.Statement.WhereStr condition = session.Statement.WhereStr
if len(colNames) > 0 { if len(colNames) > 0 {

View File

@ -9,6 +9,27 @@ import (
"time" "time"
) )
// !nashtsai! treat following var as interal const values
var (
c_EMPTY_STRING = ""
c_BOOL_DEFAULT = false
c_COMPLEX64_DEFAULT = complex64(0)
c_COMPLEX128_DEFAULT = complex128(0)
c_FLOAT32_DEFAULT = float32(0)
c_FLOAT64_DEFAULT = float64(0)
c_INT64_DEFAULT = int64(0)
c_UINT64_DEFAULT = uint64(0)
c_INT32_DEFAULT = int32(0)
c_UINT32_DEFAULT = uint32(0)
c_INT16_DEFAULT = int16(0)
c_UINT16_DEFAULT = uint16(0)
c_INT8_DEFAULT = int8(0)
c_UINT8_DEFAULT = uint8(0)
c_INT_DEFAULT = int(0)
c_UINT_DEFAULT = uint(0)
c_TIME_DEFAULT time.Time = time.Unix(0, 0)
)
// statement save all the sql info for executing SQL // statement save all the sql info for executing SQL
type Statement struct { type Statement struct {
RefTable *Table RefTable *Table
@ -16,6 +37,7 @@ type Statement struct {
Start int Start int
LimitN int LimitN int
WhereStr string WhereStr string
IdParam *PK
Params []interface{} Params []interface{}
OrderStr string OrderStr string
JoinStr string JoinStr string
@ -394,15 +416,40 @@ func (statement *Statement) TableName() string {
return "" return ""
} }
// Generate "Where id = ? " statment // Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) Id(id int64) *Statement { func (statement *Statement) Id(id interface{}) *Statement {
if statement.WhereStr == "" {
statement.WhereStr = "(id)=?" idValue := reflect.ValueOf(id)
statement.Params = []interface{}{id} idType := reflect.TypeOf(idValue.Interface())
} else {
statement.WhereStr = statement.WhereStr + " AND (id)=?" switch idType {
statement.Params = append(statement.Params, id) case reflect.TypeOf(&PK{}):
if pkPtr, ok := (id).(*PK); ok {
statement.IdParam = pkPtr
} }
case reflect.TypeOf(PK{}):
if pk, ok := (id).(PK); ok {
statement.IdParam = &pk
}
default:
// TODO treat as int primitve for now, need to handle type check
statement.IdParam = &PK{id}
// !nashtsai! REVIEW although it will be user's mistake if called Id() twice with
// different value and Id should be PK's field name, however, at this stage probably
// can't tell which table is gonna be used
// if statement.WhereStr == "" {
// statement.WhereStr = "(id)=?"
// statement.Params = []interface{}{id}
// } else {
// // TODO what if id param has already passed
// statement.WhereStr = statement.WhereStr + " AND (id)=?"
// statement.Params = append(statement.Params, id)
// }
}
// !nashtsai! perhaps no need to validate pk values' type just let sql complaint happen
return statement return statement
} }
@ -559,14 +606,37 @@ func (statement *Statement) genColumnStr() string {
return strings.Join(colNames, ", ") return strings.Join(colNames, ", ")
} }
func (statement *Statement) genCreateSQL() string { func (statement *Statement) genCreateTableSQL() string {
sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " (" sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " ("
pkList := []string{}
for _, colName := range statement.RefTable.ColumnsSeq { for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[colName] col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey {
pkList = append(pkList, col.Name)
}
}
statement.Engine.LogDebug("len:", len(pkList))
for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(statement.Engine.dialect) sql += col.String(statement.Engine.dialect)
} else {
sql += col.stringNoPk(statement.Engine.dialect)
}
sql = strings.TrimSpace(sql) sql = strings.TrimSpace(sql)
sql += ", " sql += ", "
} }
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += strings.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2] + ")" sql = sql[:len(sql)-2] + ")"
if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" { if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" {
sql += " ENGINE=" + statement.StoreEngine sql += " ENGINE=" + statement.StoreEngine
@ -702,11 +772,14 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.IsDistinct { if statement.IsDistinct {
distinct = "DISTINCT " distinct = "DISTINCT "
} }
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr, a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr,
statement.Engine.Quote(statement.TableName())) statement.Engine.Quote(statement.TableName()))
if statement.JoinStr != "" { if statement.JoinStr != "" {
a = fmt.Sprintf("%v %v", a, statement.JoinStr) a = fmt.Sprintf("%v %v", a, statement.JoinStr)
} }
statement.processIdParam()
if statement.WhereStr != "" { if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ConditionStr != "" { if statement.ConditionStr != "" {
@ -732,3 +805,34 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
} }
return return
} }
func (statement *Statement) processIdParam() {
if statement.IdParam != nil {
i := 0
colCnt := len(statement.RefTable.ColumnsSeq)
for _, elem := range *(statement.IdParam) {
for ; i < colCnt; i++ {
colName := statement.RefTable.ColumnsSeq[i]
col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), elem)
i++
break
}
}
}
// !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it
// as empty string for now, so this will result sql exec failed instead of unexpected
// false update/delete
for ; i < colCnt; i++ {
colName := statement.RefTable.ColumnsSeq[i]
col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), "")
}
}
}
}

View File

@ -155,25 +155,25 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
} }
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) { func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
typeStr := t.String()
has = true has = true
switch typeStr { switch t {
case "*string": case reflect.TypeOf(&c_EMPTY_STRING):
st = SQLType{Varchar, 255, 0} st = SQLType{Varchar, 255, 0}
case "*bool": return
case reflect.TypeOf(&c_BOOL_DEFAULT):
st = SQLType{Bool, 0, 0} st = SQLType{Bool, 0, 0}
case "*complex64", "*complex128": case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
st = SQLType{Varchar, 64, 0} st = SQLType{Varchar, 64, 0}
case "*float32": case reflect.TypeOf(&c_FLOAT32_DEFAULT):
st = SQLType{Float, 0, 0} st = SQLType{Float, 0, 0}
case "*float64": case reflect.TypeOf(&c_FLOAT64_DEFAULT):
st = SQLType{Double, 0, 0} st = SQLType{Double, 0, 0}
case "*int64", "*uint64": case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
st = SQLType{BigInt, 0, 0} st = SQLType{BigInt, 0, 0}
case "*time.Time": case reflect.TypeOf(&c_TIME_DEFAULT):
st = SQLType{DateTime, 0, 0} st = SQLType{DateTime, 0, 0}
case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8": case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
st = SQLType{Int, 0, 0} st = SQLType{Int, 0, 0}
default: default:
has = false has = false
@ -265,11 +265,28 @@ func (col *Column) String(d dialect) string {
if col.IsPrimaryKey { if col.IsPrimaryKey {
sql += "PRIMARY KEY " sql += "PRIMARY KEY "
}
if col.IsAutoIncrement { if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " " sql += d.AutoIncrStr() + " "
} }
}
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
func (col *Column) stringNoPk(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " "
if col.Nullable { if col.Nullable {
sql += "NULL " sql += "NULL "