implemented composite key
This commit is contained in:
parent
8b84f5d692
commit
f2b20d510c
137
base_test.go
137
base_test.go
|
@ -1892,21 +1892,21 @@ func (p *ProcessorsStruct) AfterDelete() {
|
|||
}
|
||||
|
||||
func testProcessors(engine *Engine, t *testing.T) {
|
||||
tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
// tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
tempEngine.ShowSQL = true
|
||||
err = tempEngine.DropTables(&ProcessorsStruct{})
|
||||
engine.ShowSQL = true
|
||||
err := engine.DropTables(&ProcessorsStruct{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
p := &ProcessorsStruct{}
|
||||
|
||||
err = tempEngine.CreateTables(&ProcessorsStruct{})
|
||||
err = engine.CreateTables(&ProcessorsStruct{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -1928,7 +1928,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
_, err = tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
|
||||
_, err = engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -1948,7 +1948,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
}
|
||||
|
||||
p2 := &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(p.Id).Get(p2)
|
||||
_, err = engine.Id(p.Id).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -1987,7 +1987,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
|
||||
p = p2 // reset
|
||||
|
||||
_, err = tempEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
|
||||
_, err = engine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2007,7 +2007,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
}
|
||||
|
||||
p2 = &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(p.Id).Get(p2)
|
||||
_, err = engine.Id(p.Id).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2045,7 +2045,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
}
|
||||
|
||||
p = p2 // reset
|
||||
_, err = tempEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
|
||||
_, err = engine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2069,7 +2069,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
pslice := make([]*ProcessorsStruct, 0)
|
||||
pslice = append(pslice, &ProcessorsStruct{})
|
||||
pslice = append(pslice, &ProcessorsStruct{})
|
||||
cnt, err := tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice)
|
||||
cnt, err := engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2095,7 +2095,7 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
|
||||
for _, elem := range pslice {
|
||||
p = &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(elem.Id).Get(p)
|
||||
_, err = engine.Id(elem.Id).Get(p)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2118,27 +2118,27 @@ func testProcessors(engine *Engine, t *testing.T) {
|
|||
}
|
||||
|
||||
func testProcessorsTx(engine *Engine, t *testing.T) {
|
||||
tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
|
||||
// tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
|
||||
// if err != nil {
|
||||
// t.Error(err)
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
// tempEngine.ShowSQL = true
|
||||
err := engine.DropTables(&ProcessorsStruct{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tempEngine.ShowSQL = true
|
||||
err = tempEngine.DropTables(&ProcessorsStruct{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = tempEngine.CreateTables(&ProcessorsStruct{})
|
||||
err = engine.CreateTables(&ProcessorsStruct{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// test insert processors with tx rollback
|
||||
session := tempEngine.NewSession()
|
||||
session := engine.NewSession()
|
||||
err = session.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -2200,7 +2200,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
}
|
||||
session.Close()
|
||||
p2 := &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(p.Id).Get(p2)
|
||||
_, err = engine.Id(p.Id).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2214,7 +2214,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
// --
|
||||
|
||||
// test insert processors with tx commit
|
||||
session = tempEngine.NewSession()
|
||||
session = engine.NewSession()
|
||||
err = session.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -2261,7 +2261,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
}
|
||||
session.Close()
|
||||
p2 = &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(p.Id).Get(p2)
|
||||
_, err = engine.Id(p.Id).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2283,7 +2283,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
// --
|
||||
|
||||
// test update processors with tx rollback
|
||||
session = tempEngine.NewSession()
|
||||
session = engine.NewSession()
|
||||
err = session.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -2347,7 +2347,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
|
||||
session.Close()
|
||||
p2 = &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(insertedId).Get(p2)
|
||||
_, err = engine.Id(insertedId).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2368,7 +2368,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
// --
|
||||
|
||||
// test update processors with tx commit
|
||||
session = tempEngine.NewSession()
|
||||
session = engine.NewSession()
|
||||
err = session.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -2415,7 +2415,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
}
|
||||
session.Close()
|
||||
p2 = &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(insertedId).Get(p2)
|
||||
_, err = engine.Id(insertedId).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2436,7 +2436,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
// --
|
||||
|
||||
// test delete processors with tx rollback
|
||||
session = tempEngine.NewSession()
|
||||
session = engine.NewSession()
|
||||
err = session.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -2500,7 +2500,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
session.Close()
|
||||
|
||||
p2 = &ProcessorsStruct{}
|
||||
_, err = tempEngine.Id(insertedId).Get(p2)
|
||||
_, err = engine.Id(insertedId).Get(p2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
|
@ -2521,7 +2521,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
|
|||
// --
|
||||
|
||||
// test delete processors with tx commit
|
||||
session = tempEngine.NewSession()
|
||||
session = engine.NewSession()
|
||||
err = session.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
@ -3253,6 +3253,71 @@ func testNullValue(engine *Engine, t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
type CompositeKey struct {
|
||||
Id1 int64 `xorm:"id1 pk"`
|
||||
Id2 int64 `xorm:"id2 pk"`
|
||||
UpdateStr string
|
||||
}
|
||||
|
||||
func testCompositeKey(engine *Engine, t *testing.T) {
|
||||
|
||||
err := engine.DropTables(&CompositeKey{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = engine.CreateTables(&CompositeKey{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
cnt, err := engine.Insert(&CompositeKey{11, 22, ""})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if cnt != 1 {
|
||||
t.Error(errors.New("failed to insert CompositeKey{11, 22}"))
|
||||
}
|
||||
|
||||
cnt, err = engine.Insert(&CompositeKey{11, 22, ""})
|
||||
if err == nil || cnt == 1 {
|
||||
t.Error(errors.New("inserted CompositeKey{11, 22}"))
|
||||
}
|
||||
|
||||
var compositeKeyVal CompositeKey
|
||||
has, err := engine.Id(PK{11, 22}).Get(&compositeKeyVal)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if !has {
|
||||
t.Error(errors.New("can't get CompositeKey{11, 22}"))
|
||||
}
|
||||
|
||||
// test passing PK ptr, this test seem failed withCache
|
||||
has, err = engine.Id(&PK{11, 22}).Get(&compositeKeyVal)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if !has {
|
||||
t.Error(errors.New("can't get CompositeKey{11, 22}"))
|
||||
}
|
||||
|
||||
compositeKeyVal = CompositeKey{UpdateStr:"test1"}
|
||||
cnt, err = engine.Id(PK{11, 22}).Update(&compositeKeyVal)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if cnt != 1 {
|
||||
t.Error(errors.New("can't update CompositeKey{11, 22}"))
|
||||
}
|
||||
|
||||
cnt, err = engine.Id(PK{11, 22}).Delete(&CompositeKey{})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if cnt != 1 {
|
||||
t.Error(errors.New("can't delete CompositeKey{11, 22}"))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func testAll(engine *Engine, t *testing.T) {
|
||||
fmt.Println("-------------- directCreateTable --------------")
|
||||
directCreateTable(engine, t)
|
||||
|
@ -3363,4 +3428,6 @@ func testAll3(engine *Engine, t *testing.T) {
|
|||
testPointerData(engine, t)
|
||||
fmt.Println("-------------- insert null data --------------")
|
||||
testNullValue(engine, t)
|
||||
fmt.Println("-------------- testCompositeKey --------------")
|
||||
testCompositeKey(engine, t)
|
||||
}
|
||||
|
|
|
@ -40,6 +40,8 @@ type dialect interface {
|
|||
GetIndexes(tableName string) (map[string]*Index, error)
|
||||
}
|
||||
|
||||
type PK []interface{}
|
||||
|
||||
// Engine is the major struct of xorm, it means a database manager.
|
||||
// Commonly, an application only need one engine
|
||||
type Engine struct {
|
||||
|
@ -269,7 +271,7 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
|
|||
}
|
||||
|
||||
// Id mehtod provoide a condition as (id) = ?
|
||||
func (engine *Engine) Id(id int64) *Session {
|
||||
func (engine *Engine) Id(id interface{}) *Session {
|
||||
session := engine.NewSession()
|
||||
session.IsAutoClose = true
|
||||
return session.Id(id)
|
||||
|
|
65
session.go
65
session.go
|
@ -90,7 +90,7 @@ func (session *Session) Or(querystring string, args ...interface{}) *Session {
|
|||
}
|
||||
|
||||
// Method Id provides converting id as a query condition
|
||||
func (session *Session) Id(id int64) *Session {
|
||||
func (session *Session) Id(id interface{}) *Session {
|
||||
session.Statement.Id(id)
|
||||
return session
|
||||
}
|
||||
|
@ -490,7 +490,8 @@ func (session *Session) CreateUniques(bean interface{}) error {
|
|||
}
|
||||
|
||||
func (session *Session) createOneTable() error {
|
||||
sql := session.Statement.genCreateSQL()
|
||||
sql := session.Statement.genCreateTableSQL()
|
||||
session.Engine.LogDebug("create table sql: [", sql, "]")
|
||||
_, err := session.exec(sql)
|
||||
return err
|
||||
}
|
||||
|
@ -1580,6 +1581,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
var v interface{}
|
||||
key := col.Name
|
||||
fieldType := fieldValue.Type()
|
||||
|
||||
|
||||
//fmt.Println("column name:", key, ", fieldType:", fieldType.String())
|
||||
switch fieldType.Kind() {
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
|
@ -1730,19 +1733,22 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
}
|
||||
case reflect.Ptr:
|
||||
// !nashtsai! TODO merge duplicated codes above
|
||||
typeStr := fieldType.String()
|
||||
switch typeStr {
|
||||
case "*string":
|
||||
//typeStr := fieldType.String()
|
||||
switch fieldType {
|
||||
// case "*string":
|
||||
case reflect.TypeOf(&c_EMPTY_STRING):
|
||||
x := string(data)
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*bool":
|
||||
// case "*bool":
|
||||
case reflect.TypeOf(&c_BOOL_DEFAULT):
|
||||
d := string(data)
|
||||
v, err := strconv.ParseBool(d)
|
||||
if err != nil {
|
||||
return errors.New("arg " + key + " as bool: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&v))
|
||||
case "*complex64":
|
||||
// case "*complex64":
|
||||
case reflect.TypeOf(&c_COMPLEX64_DEFAULT):
|
||||
var x complex64
|
||||
err := json.Unmarshal(data, &x)
|
||||
if err != nil {
|
||||
|
@ -1750,7 +1756,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
return err
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*complex128":
|
||||
// case "*complex128":
|
||||
case reflect.TypeOf(&c_COMPLEX128_DEFAULT):
|
||||
var x complex128
|
||||
err := json.Unmarshal(data, &x)
|
||||
if err != nil {
|
||||
|
@ -1758,13 +1765,15 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
return err
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*float64":
|
||||
// case "*float64":
|
||||
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
|
||||
x, err := strconv.ParseFloat(string(data), 64)
|
||||
if err != nil {
|
||||
return errors.New("arg " + key + " as float64: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*float32":
|
||||
// case "*float32":
|
||||
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
|
||||
var x float32
|
||||
x1, err := strconv.ParseFloat(string(data), 32)
|
||||
if err != nil {
|
||||
|
@ -1772,7 +1781,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
}
|
||||
x = float32(x1)
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*time.Time":
|
||||
// case "*time.Time":
|
||||
case reflect.TypeOf(&c_TIME_DEFAULT):
|
||||
sdata := strings.TrimSpace(string(data))
|
||||
var x time.Time
|
||||
var err error
|
||||
|
@ -1809,14 +1819,16 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
|
||||
v = x
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*uint64":
|
||||
// case "*uint64":
|
||||
case reflect.TypeOf(&c_UINT64_DEFAULT):
|
||||
var x uint64
|
||||
x, err := strconv.ParseUint(string(data), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New("arg " + key + " as int: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*uint":
|
||||
// case "*uint":
|
||||
case reflect.TypeOf(&c_UINT_DEFAULT):
|
||||
var x uint
|
||||
x1, err := strconv.ParseUint(string(data), 10, 64)
|
||||
if err != nil {
|
||||
|
@ -1824,7 +1836,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
}
|
||||
x = uint(x1)
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*uint32":
|
||||
// case "*uint32":
|
||||
case reflect.TypeOf(&c_UINT32_DEFAULT):
|
||||
var x uint32
|
||||
x1, err := strconv.ParseUint(string(data), 10, 64)
|
||||
if err != nil {
|
||||
|
@ -1832,7 +1845,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
}
|
||||
x = uint32(x1)
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*uint8":
|
||||
// case "*uint8":
|
||||
case reflect.TypeOf(&c_UINT8_DEFAULT):
|
||||
var x uint8
|
||||
x1, err := strconv.ParseUint(string(data), 10, 64)
|
||||
if err != nil {
|
||||
|
@ -1840,7 +1854,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
}
|
||||
x = uint8(x1)
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*uint16":
|
||||
// case "*uint16":
|
||||
case reflect.TypeOf(&c_UINT16_DEFAULT):
|
||||
var x uint16
|
||||
x1, err := strconv.ParseUint(string(data), 10, 64)
|
||||
if err != nil {
|
||||
|
@ -1848,7 +1863,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
}
|
||||
x = uint16(x1)
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*int64":
|
||||
// case "*int64":
|
||||
case reflect.TypeOf(&c_INT64_DEFAULT):
|
||||
sdata := string(data)
|
||||
var x int64
|
||||
var err error
|
||||
|
@ -1872,7 +1888,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
return errors.New("arg " + key + " as int: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*int":
|
||||
// case "*int":
|
||||
case reflect.TypeOf(&c_INT_DEFAULT):
|
||||
sdata := string(data)
|
||||
var x int
|
||||
var x1 int64
|
||||
|
@ -1900,7 +1917,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
return errors.New("arg " + key + " as int: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*int32":
|
||||
// case "*int32":
|
||||
case reflect.TypeOf(&c_INT32_DEFAULT):
|
||||
sdata := string(data)
|
||||
var x int32
|
||||
var x1 int64
|
||||
|
@ -1928,7 +1946,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
return errors.New("arg " + key + " as int: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*int8":
|
||||
// case "*int8":
|
||||
case reflect.TypeOf(&c_INT8_DEFAULT):
|
||||
sdata := string(data)
|
||||
var x int8
|
||||
var x1 int64
|
||||
|
@ -1956,7 +1975,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
|
|||
return errors.New("arg " + key + " as int: " + err.Error())
|
||||
}
|
||||
fieldValue.Set(reflect.ValueOf(&x))
|
||||
case "*int16":
|
||||
// case "*int16":
|
||||
case reflect.TypeOf(&c_INT16_DEFAULT):
|
||||
sdata := string(data)
|
||||
var x int16
|
||||
var x1 int64
|
||||
|
@ -2494,6 +2514,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
|
|||
}
|
||||
|
||||
var condition = ""
|
||||
session.Statement.processIdParam()
|
||||
st := session.Statement
|
||||
defer session.Statement.Init()
|
||||
if st.WhereStr != "" {
|
||||
|
@ -2680,6 +2701,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
|
|||
false, session.Statement.allUseBool, session.Statement.boolColumnMap)
|
||||
|
||||
var condition = ""
|
||||
|
||||
session.Statement.processIdParam()
|
||||
if session.Statement.WhereStr != "" {
|
||||
condition = session.Statement.WhereStr
|
||||
if len(colNames) > 0 {
|
||||
|
|
128
statement.go
128
statement.go
|
@ -9,6 +9,27 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// !nashtsai! treat following var as interal const values
|
||||
var (
|
||||
c_EMPTY_STRING = ""
|
||||
c_BOOL_DEFAULT = false
|
||||
c_COMPLEX64_DEFAULT = complex64(0)
|
||||
c_COMPLEX128_DEFAULT = complex128(0)
|
||||
c_FLOAT32_DEFAULT = float32(0)
|
||||
c_FLOAT64_DEFAULT = float64(0)
|
||||
c_INT64_DEFAULT = int64(0)
|
||||
c_UINT64_DEFAULT = uint64(0)
|
||||
c_INT32_DEFAULT = int32(0)
|
||||
c_UINT32_DEFAULT = uint32(0)
|
||||
c_INT16_DEFAULT = int16(0)
|
||||
c_UINT16_DEFAULT = uint16(0)
|
||||
c_INT8_DEFAULT = int8(0)
|
||||
c_UINT8_DEFAULT = uint8(0)
|
||||
c_INT_DEFAULT = int(0)
|
||||
c_UINT_DEFAULT = uint(0)
|
||||
c_TIME_DEFAULT time.Time = time.Unix(0, 0)
|
||||
)
|
||||
|
||||
// statement save all the sql info for executing SQL
|
||||
type Statement struct {
|
||||
RefTable *Table
|
||||
|
@ -16,6 +37,7 @@ type Statement struct {
|
|||
Start int
|
||||
LimitN int
|
||||
WhereStr string
|
||||
IdParam *PK
|
||||
Params []interface{}
|
||||
OrderStr string
|
||||
JoinStr string
|
||||
|
@ -256,7 +278,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
|
|||
if fieldValue.IsNil() {
|
||||
if includeNil {
|
||||
args = append(args, nil)
|
||||
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
|
||||
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
|
||||
}
|
||||
continue
|
||||
} else if !fieldValue.IsValid() {
|
||||
|
@ -376,7 +398,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
|
|||
}
|
||||
|
||||
args = append(args, val)
|
||||
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
|
||||
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
|
||||
}
|
||||
|
||||
return colNames, args
|
||||
|
@ -394,15 +416,40 @@ func (statement *Statement) TableName() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
// Generate "Where id = ? " statment
|
||||
func (statement *Statement) Id(id int64) *Statement {
|
||||
if statement.WhereStr == "" {
|
||||
statement.WhereStr = "(id)=?"
|
||||
statement.Params = []interface{}{id}
|
||||
} else {
|
||||
statement.WhereStr = statement.WhereStr + " AND (id)=?"
|
||||
statement.Params = append(statement.Params, id)
|
||||
// Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
|
||||
func (statement *Statement) Id(id interface{}) *Statement {
|
||||
|
||||
idValue := reflect.ValueOf(id)
|
||||
idType := reflect.TypeOf(idValue.Interface())
|
||||
|
||||
switch idType {
|
||||
case reflect.TypeOf(&PK{}):
|
||||
if pkPtr, ok := (id).(*PK); ok {
|
||||
statement.IdParam = pkPtr
|
||||
}
|
||||
case reflect.TypeOf(PK{}):
|
||||
if pk, ok := (id).(PK); ok {
|
||||
statement.IdParam = &pk
|
||||
}
|
||||
default:
|
||||
// TODO treat as int primitve for now, need to handle type check
|
||||
statement.IdParam = &PK{id}
|
||||
|
||||
// !nashtsai! REVIEW although it will be user's mistake if called Id() twice with
|
||||
// different value and Id should be PK's field name, however, at this stage probably
|
||||
// can't tell which table is gonna be used
|
||||
// if statement.WhereStr == "" {
|
||||
// statement.WhereStr = "(id)=?"
|
||||
// statement.Params = []interface{}{id}
|
||||
// } else {
|
||||
// // TODO what if id param has already passed
|
||||
// statement.WhereStr = statement.WhereStr + " AND (id)=?"
|
||||
// statement.Params = append(statement.Params, id)
|
||||
// }
|
||||
}
|
||||
|
||||
// !nashtsai! perhaps no need to validate pk values' type just let sql complaint happen
|
||||
|
||||
return statement
|
||||
}
|
||||
|
||||
|
@ -559,14 +606,37 @@ func (statement *Statement) genColumnStr() string {
|
|||
return strings.Join(colNames, ", ")
|
||||
}
|
||||
|
||||
func (statement *Statement) genCreateSQL() string {
|
||||
func (statement *Statement) genCreateTableSQL() string {
|
||||
sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " ("
|
||||
|
||||
|
||||
pkList := []string{}
|
||||
|
||||
for _, colName := range statement.RefTable.ColumnsSeq {
|
||||
col := statement.RefTable.Columns[colName]
|
||||
sql += col.String(statement.Engine.dialect)
|
||||
if col.IsPrimaryKey {
|
||||
pkList = append(pkList, col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
statement.Engine.LogDebug("len:", len(pkList))
|
||||
for _, colName := range statement.RefTable.ColumnsSeq {
|
||||
col := statement.RefTable.Columns[colName]
|
||||
if col.IsPrimaryKey && len(pkList) == 1 {
|
||||
sql += col.String(statement.Engine.dialect)
|
||||
} else {
|
||||
sql += col.stringNoPk(statement.Engine.dialect)
|
||||
}
|
||||
sql = strings.TrimSpace(sql)
|
||||
sql += ", "
|
||||
}
|
||||
|
||||
if len(pkList) > 1 {
|
||||
sql += "PRIMARY KEY ( "
|
||||
sql += strings.Join(pkList, ",")
|
||||
sql += " ), "
|
||||
}
|
||||
|
||||
sql = sql[:len(sql)-2] + ")"
|
||||
if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" {
|
||||
sql += " ENGINE=" + statement.StoreEngine
|
||||
|
@ -702,11 +772,14 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
|
|||
if statement.IsDistinct {
|
||||
distinct = "DISTINCT "
|
||||
}
|
||||
|
||||
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
|
||||
a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr,
|
||||
statement.Engine.Quote(statement.TableName()))
|
||||
if statement.JoinStr != "" {
|
||||
a = fmt.Sprintf("%v %v", a, statement.JoinStr)
|
||||
}
|
||||
statement.processIdParam()
|
||||
if statement.WhereStr != "" {
|
||||
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
|
||||
if statement.ConditionStr != "" {
|
||||
|
@ -732,3 +805,34 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (statement *Statement) processIdParam() {
|
||||
|
||||
if statement.IdParam != nil {
|
||||
i := 0
|
||||
colCnt := len(statement.RefTable.ColumnsSeq)
|
||||
for _, elem := range *(statement.IdParam) {
|
||||
for ; i < colCnt; i++ {
|
||||
colName := statement.RefTable.ColumnsSeq[i]
|
||||
col := statement.RefTable.Columns[colName]
|
||||
if col.IsPrimaryKey {
|
||||
statement.And(fmt.Sprintf("%v=?", col.Name), elem)
|
||||
i++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it
|
||||
// as empty string for now, so this will result sql exec failed instead of unexpected
|
||||
// false update/delete
|
||||
for ; i < colCnt; i++ {
|
||||
colName := statement.RefTable.ColumnsSeq[i]
|
||||
col := statement.RefTable.Columns[colName]
|
||||
if col.IsPrimaryKey {
|
||||
statement.And(fmt.Sprintf("%v=?", col.Name), "")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
41
table.go
41
table.go
|
@ -155,25 +155,25 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
|
|||
}
|
||||
|
||||
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
|
||||
typeStr := t.String()
|
||||
has = true
|
||||
|
||||
switch typeStr {
|
||||
case "*string":
|
||||
switch t {
|
||||
case reflect.TypeOf(&c_EMPTY_STRING):
|
||||
st = SQLType{Varchar, 255, 0}
|
||||
case "*bool":
|
||||
return
|
||||
case reflect.TypeOf(&c_BOOL_DEFAULT):
|
||||
st = SQLType{Bool, 0, 0}
|
||||
case "*complex64", "*complex128":
|
||||
case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
|
||||
st = SQLType{Varchar, 64, 0}
|
||||
case "*float32":
|
||||
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
|
||||
st = SQLType{Float, 0, 0}
|
||||
case "*float64":
|
||||
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
|
||||
st = SQLType{Double, 0, 0}
|
||||
case "*int64", "*uint64":
|
||||
case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
|
||||
st = SQLType{BigInt, 0, 0}
|
||||
case "*time.Time":
|
||||
case reflect.TypeOf(&c_TIME_DEFAULT):
|
||||
st = SQLType{DateTime, 0, 0}
|
||||
case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8":
|
||||
case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
|
||||
st = SQLType{Int, 0, 0}
|
||||
default:
|
||||
has = false
|
||||
|
@ -265,12 +265,29 @@ func (col *Column) String(d dialect) string {
|
|||
|
||||
if col.IsPrimaryKey {
|
||||
sql += "PRIMARY KEY "
|
||||
if col.IsAutoIncrement {
|
||||
sql += d.AutoIncrStr() + " "
|
||||
}
|
||||
}
|
||||
|
||||
if col.IsAutoIncrement {
|
||||
sql += d.AutoIncrStr() + " "
|
||||
if col.Nullable {
|
||||
sql += "NULL "
|
||||
} else {
|
||||
sql += "NOT NULL "
|
||||
}
|
||||
|
||||
if col.Default != "" {
|
||||
sql += "DEFAULT " + col.Default + " "
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func (col *Column) stringNoPk(d dialect) string {
|
||||
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
|
||||
|
||||
sql += d.SqlType(col) + " "
|
||||
|
||||
if col.Nullable {
|
||||
sql += "NULL "
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue