merge composite key

This commit is contained in:
Lunny Xiao 2013-12-17 09:38:20 +08:00
parent 81a4c102b9
commit 99c7031b50
5 changed files with 294 additions and 81 deletions

View File

@ -1892,21 +1892,21 @@ func (p *ProcessorsStruct) AfterDelete() {
}
func testProcessors(engine *Engine, t *testing.T) {
tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
if err != nil {
t.Error(err)
panic(err)
}
// tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
// if err != nil {
// t.Error(err)
// panic(err)
// }
tempEngine.ShowSQL = true
err = tempEngine.DropTables(&ProcessorsStruct{})
engine.ShowSQL = true
err := engine.DropTables(&ProcessorsStruct{})
if err != nil {
t.Error(err)
panic(err)
}
p := &ProcessorsStruct{}
err = tempEngine.CreateTables(&ProcessorsStruct{})
err = engine.CreateTables(&ProcessorsStruct{})
if err != nil {
t.Error(err)
panic(err)
@ -1928,7 +1928,7 @@ func testProcessors(engine *Engine, t *testing.T) {
}
}
_, err = tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
_, err = engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(p)
if err != nil {
t.Error(err)
panic(err)
@ -1948,7 +1948,7 @@ func testProcessors(engine *Engine, t *testing.T) {
}
p2 := &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2)
_, err = engine.Id(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -1987,7 +1987,7 @@ func testProcessors(engine *Engine, t *testing.T) {
p = p2 // reset
_, err = tempEngine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
_, err = engine.Before(b4UpdateFunc).After(afterUpdateFunc).Update(p)
if err != nil {
t.Error(err)
panic(err)
@ -2007,7 +2007,7 @@ func testProcessors(engine *Engine, t *testing.T) {
}
p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2)
_, err = engine.Id(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -2045,7 +2045,7 @@ func testProcessors(engine *Engine, t *testing.T) {
}
p = p2 // reset
_, err = tempEngine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
_, err = engine.Before(b4DeleteFunc).After(afterDeleteFunc).Delete(p)
if err != nil {
t.Error(err)
panic(err)
@ -2069,7 +2069,7 @@ func testProcessors(engine *Engine, t *testing.T) {
pslice := make([]*ProcessorsStruct, 0)
pslice = append(pslice, &ProcessorsStruct{})
pslice = append(pslice, &ProcessorsStruct{})
cnt, err := tempEngine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice)
cnt, err := engine.Before(b4InsertFunc).After(afterInsertFunc).Insert(&pslice)
if err != nil {
t.Error(err)
panic(err)
@ -2095,7 +2095,7 @@ func testProcessors(engine *Engine, t *testing.T) {
for _, elem := range pslice {
p = &ProcessorsStruct{}
_, err = tempEngine.Id(elem.Id).Get(p)
_, err = engine.Id(elem.Id).Get(p)
if err != nil {
t.Error(err)
panic(err)
@ -2118,27 +2118,27 @@ func testProcessors(engine *Engine, t *testing.T) {
}
func testProcessorsTx(engine *Engine, t *testing.T) {
tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
// tempEngine, err := NewEngine(engine.DriverName, engine.DataSourceName)
// if err != nil {
// t.Error(err)
// panic(err)
// }
// tempEngine.ShowSQL = true
err := engine.DropTables(&ProcessorsStruct{})
if err != nil {
t.Error(err)
panic(err)
}
tempEngine.ShowSQL = true
err = tempEngine.DropTables(&ProcessorsStruct{})
if err != nil {
t.Error(err)
panic(err)
}
err = tempEngine.CreateTables(&ProcessorsStruct{})
err = engine.CreateTables(&ProcessorsStruct{})
if err != nil {
t.Error(err)
panic(err)
}
// test insert processors with tx rollback
session := tempEngine.NewSession()
session := engine.NewSession()
err = session.Begin()
if err != nil {
t.Error(err)
@ -2200,7 +2200,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
}
session.Close()
p2 := &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2)
_, err = engine.Id(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -2214,7 +2214,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// --
// test insert processors with tx commit
session = tempEngine.NewSession()
session = engine.NewSession()
err = session.Begin()
if err != nil {
t.Error(err)
@ -2261,7 +2261,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
}
session.Close()
p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(p.Id).Get(p2)
_, err = engine.Id(p.Id).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -2283,7 +2283,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// --
// test update processors with tx rollback
session = tempEngine.NewSession()
session = engine.NewSession()
err = session.Begin()
if err != nil {
t.Error(err)
@ -2347,7 +2347,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
session.Close()
p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(insertedId).Get(p2)
_, err = engine.Id(insertedId).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -2368,7 +2368,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// --
// test update processors with tx commit
session = tempEngine.NewSession()
session = engine.NewSession()
err = session.Begin()
if err != nil {
t.Error(err)
@ -2415,7 +2415,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
}
session.Close()
p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(insertedId).Get(p2)
_, err = engine.Id(insertedId).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -2436,7 +2436,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// --
// test delete processors with tx rollback
session = tempEngine.NewSession()
session = engine.NewSession()
err = session.Begin()
if err != nil {
t.Error(err)
@ -2500,7 +2500,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
session.Close()
p2 = &ProcessorsStruct{}
_, err = tempEngine.Id(insertedId).Get(p2)
_, err = engine.Id(insertedId).Get(p2)
if err != nil {
t.Error(err)
panic(err)
@ -2521,7 +2521,7 @@ func testProcessorsTx(engine *Engine, t *testing.T) {
// --
// test delete processors with tx commit
session = tempEngine.NewSession()
session = engine.NewSession()
err = session.Begin()
if err != nil {
t.Error(err)
@ -3253,6 +3253,71 @@ func testNullValue(engine *Engine, t *testing.T) {
}
type CompositeKey struct {
Id1 int64 `xorm:"id1 pk"`
Id2 int64 `xorm:"id2 pk"`
UpdateStr string
}
func testCompositeKey(engine *Engine, t *testing.T) {
err := engine.DropTables(&CompositeKey{})
if err != nil {
t.Error(err)
panic(err)
}
err = engine.CreateTables(&CompositeKey{})
if err != nil {
t.Error(err)
panic(err)
}
cnt, err := engine.Insert(&CompositeKey{11, 22, ""})
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("failed to insert CompositeKey{11, 22}"))
}
cnt, err = engine.Insert(&CompositeKey{11, 22, ""})
if err == nil || cnt == 1 {
t.Error(errors.New("inserted CompositeKey{11, 22}"))
}
var compositeKeyVal CompositeKey
has, err := engine.Id(PK{11, 22}).Get(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if !has {
t.Error(errors.New("can't get CompositeKey{11, 22}"))
}
// test passing PK ptr, this test seem failed withCache
has, err = engine.Id(&PK{11, 22}).Get(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if !has {
t.Error(errors.New("can't get CompositeKey{11, 22}"))
}
compositeKeyVal = CompositeKey{UpdateStr:"test1"}
cnt, err = engine.Id(PK{11, 22}).Update(&compositeKeyVal)
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't update CompositeKey{11, 22}"))
}
cnt, err = engine.Id(PK{11, 22}).Delete(&CompositeKey{})
if err != nil {
t.Error(err)
} else if cnt != 1 {
t.Error(errors.New("can't delete CompositeKey{11, 22}"))
}
}
func testAll(engine *Engine, t *testing.T) {
fmt.Println("-------------- directCreateTable --------------")
directCreateTable(engine, t)
@ -3363,4 +3428,6 @@ func testAll3(engine *Engine, t *testing.T) {
testPointerData(engine, t)
fmt.Println("-------------- insert null data --------------")
testNullValue(engine, t)
fmt.Println("-------------- testCompositeKey --------------")
testCompositeKey(engine, t)
}

View File

@ -40,6 +40,8 @@ type dialect interface {
GetIndexes(tableName string) (map[string]*Index, error)
}
type PK []interface{}
// Engine is the major struct of xorm, it means a database manager.
// Commonly, an application only need one engine
type Engine struct {
@ -269,7 +271,7 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Session {
}
// Id mehtod provoide a condition as (id) = ?
func (engine *Engine) Id(id int64) *Session {
func (engine *Engine) Id(id interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
return session.Id(id)

View File

@ -90,7 +90,7 @@ func (session *Session) Or(querystring string, args ...interface{}) *Session {
}
// Method Id provides converting id as a query condition
func (session *Session) Id(id int64) *Session {
func (session *Session) Id(id interface{}) *Session {
session.Statement.Id(id)
return session
}
@ -490,7 +490,8 @@ func (session *Session) CreateUniques(bean interface{}) error {
}
func (session *Session) createOneTable() error {
sql := session.Statement.genCreateSQL()
sql := session.Statement.genCreateTableSQL()
session.Engine.LogDebug("create table sql: [", sql, "]")
_, err := session.exec(sql)
return err
}
@ -1580,6 +1581,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
var v interface{}
key := col.Name
fieldType := fieldValue.Type()
//fmt.Println("column name:", key, ", fieldType:", fieldType.String())
switch fieldType.Kind() {
case reflect.Complex64, reflect.Complex128:
@ -1730,19 +1733,22 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
}
case reflect.Ptr:
// !nashtsai! TODO merge duplicated codes above
typeStr := fieldType.String()
switch typeStr {
case "*string":
//typeStr := fieldType.String()
switch fieldType {
// case "*string":
case reflect.TypeOf(&c_EMPTY_STRING):
x := string(data)
fieldValue.Set(reflect.ValueOf(&x))
case "*bool":
// case "*bool":
case reflect.TypeOf(&c_BOOL_DEFAULT):
d := string(data)
v, err := strconv.ParseBool(d)
if err != nil {
return errors.New("arg " + key + " as bool: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&v))
case "*complex64":
// case "*complex64":
case reflect.TypeOf(&c_COMPLEX64_DEFAULT):
var x complex64
err := json.Unmarshal(data, &x)
if err != nil {
@ -1750,7 +1756,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return err
}
fieldValue.Set(reflect.ValueOf(&x))
case "*complex128":
// case "*complex128":
case reflect.TypeOf(&c_COMPLEX128_DEFAULT):
var x complex128
err := json.Unmarshal(data, &x)
if err != nil {
@ -1758,13 +1765,15 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return err
}
fieldValue.Set(reflect.ValueOf(&x))
case "*float64":
// case "*float64":
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
x, err := strconv.ParseFloat(string(data), 64)
if err != nil {
return errors.New("arg " + key + " as float64: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&x))
case "*float32":
// case "*float32":
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
var x float32
x1, err := strconv.ParseFloat(string(data), 32)
if err != nil {
@ -1772,7 +1781,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
}
x = float32(x1)
fieldValue.Set(reflect.ValueOf(&x))
case "*time.Time":
// case "*time.Time":
case reflect.TypeOf(&c_TIME_DEFAULT):
sdata := strings.TrimSpace(string(data))
var x time.Time
var err error
@ -1809,14 +1819,16 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
v = x
fieldValue.Set(reflect.ValueOf(&x))
case "*uint64":
// case "*uint64":
case reflect.TypeOf(&c_UINT64_DEFAULT):
var x uint64
x, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&x))
case "*uint":
// case "*uint":
case reflect.TypeOf(&c_UINT_DEFAULT):
var x uint
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -1824,7 +1836,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
}
x = uint(x1)
fieldValue.Set(reflect.ValueOf(&x))
case "*uint32":
// case "*uint32":
case reflect.TypeOf(&c_UINT32_DEFAULT):
var x uint32
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -1832,7 +1845,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
}
x = uint32(x1)
fieldValue.Set(reflect.ValueOf(&x))
case "*uint8":
// case "*uint8":
case reflect.TypeOf(&c_UINT8_DEFAULT):
var x uint8
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -1840,7 +1854,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
}
x = uint8(x1)
fieldValue.Set(reflect.ValueOf(&x))
case "*uint16":
// case "*uint16":
case reflect.TypeOf(&c_UINT16_DEFAULT):
var x uint16
x1, err := strconv.ParseUint(string(data), 10, 64)
if err != nil {
@ -1848,7 +1863,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
}
x = uint16(x1)
fieldValue.Set(reflect.ValueOf(&x))
case "*int64":
// case "*int64":
case reflect.TypeOf(&c_INT64_DEFAULT):
sdata := string(data)
var x int64
var err error
@ -1872,7 +1888,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&x))
case "*int":
// case "*int":
case reflect.TypeOf(&c_INT_DEFAULT):
sdata := string(data)
var x int
var x1 int64
@ -1900,7 +1917,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&x))
case "*int32":
// case "*int32":
case reflect.TypeOf(&c_INT32_DEFAULT):
sdata := string(data)
var x int32
var x1 int64
@ -1928,7 +1946,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&x))
case "*int8":
// case "*int8":
case reflect.TypeOf(&c_INT8_DEFAULT):
sdata := string(data)
var x int8
var x1 int64
@ -1956,7 +1975,8 @@ func (session *Session) bytes2Value(col *Column, fieldValue *reflect.Value, data
return errors.New("arg " + key + " as int: " + err.Error())
}
fieldValue.Set(reflect.ValueOf(&x))
case "*int16":
// case "*int16":
case reflect.TypeOf(&c_INT16_DEFAULT):
sdata := string(data)
var x int16
var x1 int64
@ -2494,6 +2514,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
var condition = ""
session.Statement.processIdParam()
st := session.Statement
defer session.Statement.Init()
if st.WhereStr != "" {
@ -2680,6 +2701,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
false, session.Statement.allUseBool, session.Statement.boolColumnMap)
var condition = ""
session.Statement.processIdParam()
if session.Statement.WhereStr != "" {
condition = session.Statement.WhereStr
if len(colNames) > 0 {

View File

@ -9,6 +9,27 @@ import (
"time"
)
// !nashtsai! treat following var as interal const values
var (
c_EMPTY_STRING = ""
c_BOOL_DEFAULT = false
c_COMPLEX64_DEFAULT = complex64(0)
c_COMPLEX128_DEFAULT = complex128(0)
c_FLOAT32_DEFAULT = float32(0)
c_FLOAT64_DEFAULT = float64(0)
c_INT64_DEFAULT = int64(0)
c_UINT64_DEFAULT = uint64(0)
c_INT32_DEFAULT = int32(0)
c_UINT32_DEFAULT = uint32(0)
c_INT16_DEFAULT = int16(0)
c_UINT16_DEFAULT = uint16(0)
c_INT8_DEFAULT = int8(0)
c_UINT8_DEFAULT = uint8(0)
c_INT_DEFAULT = int(0)
c_UINT_DEFAULT = uint(0)
c_TIME_DEFAULT time.Time = time.Unix(0, 0)
)
// statement save all the sql info for executing SQL
type Statement struct {
RefTable *Table
@ -16,6 +37,7 @@ type Statement struct {
Start int
LimitN int
WhereStr string
IdParam *PK
Params []interface{}
OrderStr string
JoinStr string
@ -256,7 +278,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
if fieldValue.IsNil() {
if includeNil {
args = append(args, nil)
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
}
continue
} else if !fieldValue.IsValid() {
@ -376,7 +398,7 @@ func buildConditions(engine *Engine, table *Table, bean interface{},
}
args = append(args, val)
colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
}
return colNames, args
@ -394,15 +416,40 @@ func (statement *Statement) TableName() string {
return ""
}
// Generate "Where id = ? " statment
func (statement *Statement) Id(id int64) *Statement {
if statement.WhereStr == "" {
statement.WhereStr = "(id)=?"
statement.Params = []interface{}{id}
} else {
statement.WhereStr = statement.WhereStr + " AND (id)=?"
statement.Params = append(statement.Params, id)
// Generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?"
func (statement *Statement) Id(id interface{}) *Statement {
idValue := reflect.ValueOf(id)
idType := reflect.TypeOf(idValue.Interface())
switch idType {
case reflect.TypeOf(&PK{}):
if pkPtr, ok := (id).(*PK); ok {
statement.IdParam = pkPtr
}
case reflect.TypeOf(PK{}):
if pk, ok := (id).(PK); ok {
statement.IdParam = &pk
}
default:
// TODO treat as int primitve for now, need to handle type check
statement.IdParam = &PK{id}
// !nashtsai! REVIEW although it will be user's mistake if called Id() twice with
// different value and Id should be PK's field name, however, at this stage probably
// can't tell which table is gonna be used
// if statement.WhereStr == "" {
// statement.WhereStr = "(id)=?"
// statement.Params = []interface{}{id}
// } else {
// // TODO what if id param has already passed
// statement.WhereStr = statement.WhereStr + " AND (id)=?"
// statement.Params = append(statement.Params, id)
// }
}
// !nashtsai! perhaps no need to validate pk values' type just let sql complaint happen
return statement
}
@ -559,14 +606,37 @@ func (statement *Statement) genColumnStr() string {
return strings.Join(colNames, ", ")
}
func (statement *Statement) genCreateSQL() string {
func (statement *Statement) genCreateTableSQL() string {
sql := "CREATE TABLE IF NOT EXISTS " + statement.Engine.Quote(statement.TableName()) + " ("
pkList := []string{}
for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[colName]
sql += col.String(statement.Engine.dialect)
if col.IsPrimaryKey {
pkList = append(pkList, col.Name)
}
}
statement.Engine.LogDebug("len:", len(pkList))
for _, colName := range statement.RefTable.ColumnsSeq {
col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey && len(pkList) == 1 {
sql += col.String(statement.Engine.dialect)
} else {
sql += col.stringNoPk(statement.Engine.dialect)
}
sql = strings.TrimSpace(sql)
sql += ", "
}
if len(pkList) > 1 {
sql += "PRIMARY KEY ( "
sql += strings.Join(pkList, ",")
sql += " ), "
}
sql = sql[:len(sql)-2] + ")"
if statement.Engine.dialect.SupportEngine() && statement.StoreEngine != "" {
sql += " ENGINE=" + statement.StoreEngine
@ -702,11 +772,14 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
if statement.IsDistinct {
distinct = "DISTINCT "
}
// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
a = fmt.Sprintf("SELECT %v%v FROM %v", distinct, columnStr,
statement.Engine.Quote(statement.TableName()))
if statement.JoinStr != "" {
a = fmt.Sprintf("%v %v", a, statement.JoinStr)
}
statement.processIdParam()
if statement.WhereStr != "" {
a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr)
if statement.ConditionStr != "" {
@ -732,3 +805,34 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) {
}
return
}
func (statement *Statement) processIdParam() {
if statement.IdParam != nil {
i := 0
colCnt := len(statement.RefTable.ColumnsSeq)
for _, elem := range *(statement.IdParam) {
for ; i < colCnt; i++ {
colName := statement.RefTable.ColumnsSeq[i]
col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), elem)
i++
break
}
}
}
// !nashtsai! REVIEW what if statement.IdParam has insufficient pk item? handle it
// as empty string for now, so this will result sql exec failed instead of unexpected
// false update/delete
for ; i < colCnt; i++ {
colName := statement.RefTable.ColumnsSeq[i]
col := statement.RefTable.Columns[colName]
if col.IsPrimaryKey {
statement.And(fmt.Sprintf("%v=?", col.Name), "")
}
}
}
}

View File

@ -155,25 +155,25 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
}
func ptrType2SQLType(t reflect.Type) (st SQLType, has bool) {
typeStr := t.String()
has = true
switch typeStr {
case "*string":
switch t {
case reflect.TypeOf(&c_EMPTY_STRING):
st = SQLType{Varchar, 255, 0}
case "*bool":
return
case reflect.TypeOf(&c_BOOL_DEFAULT):
st = SQLType{Bool, 0, 0}
case "*complex64", "*complex128":
case reflect.TypeOf(&c_COMPLEX64_DEFAULT), reflect.TypeOf(&c_COMPLEX128_DEFAULT):
st = SQLType{Varchar, 64, 0}
case "*float32":
case reflect.TypeOf(&c_FLOAT32_DEFAULT):
st = SQLType{Float, 0, 0}
case "*float64":
case reflect.TypeOf(&c_FLOAT64_DEFAULT):
st = SQLType{Double, 0, 0}
case "*int64", "*uint64":
case reflect.TypeOf(&c_INT64_DEFAULT), reflect.TypeOf(&c_UINT64_DEFAULT):
st = SQLType{BigInt, 0, 0}
case "*time.Time":
case reflect.TypeOf(&c_TIME_DEFAULT):
st = SQLType{DateTime, 0, 0}
case "*int", "*int16", "*int32", "*int8", "*uint", "*uint16", "*uint32", "*uint8":
case reflect.TypeOf(&c_INT_DEFAULT), reflect.TypeOf(&c_INT32_DEFAULT), reflect.TypeOf(&c_INT8_DEFAULT), reflect.TypeOf(&c_INT16_DEFAULT), reflect.TypeOf(&c_UINT_DEFAULT), reflect.TypeOf(&c_UINT32_DEFAULT), reflect.TypeOf(&c_UINT8_DEFAULT), reflect.TypeOf(&c_UINT16_DEFAULT):
st = SQLType{Int, 0, 0}
default:
has = false
@ -265,12 +265,29 @@ func (col *Column) String(d dialect) string {
if col.IsPrimaryKey {
sql += "PRIMARY KEY "
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
}
}
if col.IsAutoIncrement {
sql += d.AutoIncrStr() + " "
if col.Nullable {
sql += "NULL "
} else {
sql += "NOT NULL "
}
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
return sql
}
func (col *Column) stringNoPk(d dialect) string {
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
sql += d.SqlType(col) + " "
if col.Nullable {
sql += "NULL "
} else {