Added support for Find(*[]*Struct); added notnull;

This commit is contained in:
Lunny Xiao 2013-10-05 00:44:43 +08:00
parent 4b2425d3d3
commit c0d008e631
9 changed files with 249 additions and 131 deletions

View File

@ -271,7 +271,22 @@ func find(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
fmt.Println(users) for _, user := range users {
fmt.Println(user)
}
}
func find2(engine *Engine, t *testing.T) {
users := make([]*Userinfo, 0)
err := engine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
for _, user := range users {
fmt.Println(user)
}
} }
func findMap(engine *Engine, t *testing.T) { func findMap(engine *Engine, t *testing.T) {
@ -282,7 +297,22 @@ func findMap(engine *Engine, t *testing.T) {
t.Error(err) t.Error(err)
panic(err) panic(err)
} }
fmt.Println(users) for _, user := range users {
fmt.Println(user)
}
}
func findMap2(engine *Engine, t *testing.T) {
users := make(map[int64]*Userinfo)
err := engine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
for id, user := range users {
fmt.Println(id, user)
}
} }
func count(engine *Engine, t *testing.T) { func count(engine *Engine, t *testing.T) {
@ -683,7 +713,7 @@ func testCols(engine *Engine, t *testing.T) {
fmt.Println(users) fmt.Println(users)
tmpUsers := []tempUser{} tmpUsers := []tempUser{}
err = engine.Table("userinfo").Cols("id, username").Find(&tmpUsers) err = engine.NoCache().Table("userinfo").Cols("id, username").Find(&tmpUsers)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
panic(err) panic(err)
@ -1055,8 +1085,12 @@ func testAll(engine *Engine, t *testing.T) {
cascadeGet(engine, t) cascadeGet(engine, t)
fmt.Println("-------------- find --------------") fmt.Println("-------------- find --------------")
find(engine, t) find(engine, t)
fmt.Println("-------------- find2 --------------")
find2(engine, t)
fmt.Println("-------------- findMap --------------") fmt.Println("-------------- findMap --------------")
findMap(engine, t) findMap(engine, t)
fmt.Println("-------------- findMap2 --------------")
findMap2(engine, t)
fmt.Println("-------------- count --------------") fmt.Println("-------------- count --------------")
count(engine, t) count(engine, t)
fmt.Println("-------------- where --------------") fmt.Println("-------------- where --------------")

View File

@ -46,5 +46,4 @@ func doBenchCacheFind(engine *Engine, b *testing.B) {
b.Error(err) b.Error(err)
return return
} }
} }

147
engine.go
View File

@ -44,6 +44,7 @@ type Engine struct {
ShowSQL bool ShowSQL bool
ShowErr bool ShowErr bool
ShowDebug bool ShowDebug bool
ShowWarn bool
Pool IConnectPool Pool IConnectPool
Filters []Filter Filters []Filter
Logger io.Writer Logger io.Writer
@ -159,6 +160,12 @@ func (engine *Engine) LogDebug(contents ...interface{}) {
} }
} }
func (engine *Engine) LogWarn(contents ...interface{}) {
if engine.ShowWarn {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session { func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession() session := engine.NewSession()
session.IsAutoClose = true session.IsAutoClose = true
@ -286,7 +293,8 @@ func (engine *Engine) AutoMap(bean interface{}) *Table {
} }
func (engine *Engine) newTable() *Table { func (engine *Engine) newTable() *Table {
table := &Table{Indexes: map[string][]string{}, Uniques: map[string][]string{}} table := &Table{}
table.Indexes = make(map[string]*Index)
table.Columns = make(map[string]*Column) table.Columns = make(map[string]*Column)
table.ColumnsSeq = make([]string, 0) table.ColumnsSeq = make([]string, 0)
table.Cacher = engine.Cacher table.Cacher = engine.Cacher
@ -347,22 +355,42 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
col.IsCreated = true col.IsCreated = true
case k == "UPDATED": case k == "UPDATED":
col.IsUpdated = true col.IsUpdated = true
case strings.HasPrefix(k, "INDEX"): /*case strings.HasPrefix(k, "--"):
if k == "INDEX" { col.Comment = k[2:len(k)]*/
col.IndexName = "" case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
col.IndexType = SINGLEINDEX indexName := k[len("INDEX")+1 : len(k)-1]
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col)
col.Index = index
} else { } else {
col.IndexName = k[len("INDEX")+1 : len(k)-1] index := NewIndex(indexName, false)
col.IndexType = UNIONINDEX index.AddColumn(col)
table.AddIndex(index)
col.Index = index
} }
case strings.HasPrefix(k, "UNIQUE"): case k == "INDEX":
if k == "UNIQUE" { index := NewIndex(col.Name, false)
col.UniqueName = "" index.AddColumn(col)
col.UniqueType = SINGLEUNIQUE table.AddIndex(index)
col.Index = index
case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
indexName := k[len("UNIQUE")+1 : len(k)-1]
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col)
col.Index = index
} else { } else {
col.UniqueName = k[len("UNIQUE")+1 : len(k)-1] index := NewIndex(indexName, true)
col.UniqueType = UNIONUNIQUE index.AddColumn(col)
table.AddIndex(index)
col.Index = index
} }
case k == "UNIQUE":
index := NewIndex(col.Name, true)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
case k == "NOTNULL":
col.Nullable = false
case k == "NOT": case k == "NOT":
default: default:
if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") { if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") {
@ -395,60 +423,26 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
if col.SQLType.Name == "" { if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType) col.SQLType = Type2SQLType(fieldType)
} }
if col.Length == 0 { if col.Length == 0 {
col.Length = col.SQLType.DefaultLength col.Length = col.SQLType.DefaultLength
} }
if col.Length2 == 0 { if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2 col.Length2 = col.SQLType.DefaultLength2
} }
if col.Name == "" { if col.Name == "" {
col.Name = engine.Mapper.Obj2Table(t.Field(i).Name) col.Name = engine.Mapper.Obj2Table(t.Field(i).Name)
} }
if col.IsPrimaryKey {
table.PrimaryKey = col.Name
}
if col.IsCreated {
table.Created = col.Name
}
if col.IsUpdated {
table.Updated = col.Name
}
if col.IndexType == SINGLEINDEX {
col.IndexName = col.Name
table.Indexes[col.IndexName] = []string{col.Name}
} else if col.IndexType == UNIONINDEX {
if unionIdxes, ok := table.Indexes[col.IndexName]; ok {
table.Indexes[col.IndexName] = append(unionIdxes, col.Name)
} else {
table.Indexes[col.IndexName] = []string{col.Name}
}
}
if col.UniqueType == SINGLEUNIQUE {
col.UniqueName = col.Name
table.Uniques[col.UniqueName] = []string{col.Name}
} else if col.UniqueType == UNIONUNIQUE {
if unionUniques, ok := table.Uniques[col.UniqueName]; ok {
table.Uniques[col.UniqueName] = append(unionUniques, col.Name)
} else {
table.Uniques[col.UniqueName] = []string{col.Name}
}
}
} }
} else { } else {
sqlType := Type2SQLType(fieldType) sqlType := Type2SQLType(fieldType)
col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType, col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", NONEUNIQUE, "", sqlType.DefaultLength, sqlType.DefaultLength2, true, "", nil, false, false,
NONEINDEX, "", false, false, TWOSIDES, false, false} TWOSIDES, false, false, ""}
} }
if col.IsAutoIncrement { if col.IsAutoIncrement {
col.Nullable = false col.Nullable = false
} }
if col.IsPrimaryKey {
table.PrimaryKey = col.Name
}
table.AddColumn(col) table.AddColumn(col)
if col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id") { if col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id") {
@ -487,8 +481,8 @@ func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) {
engine.AutoMapType(t) engine.AutoMapType(t)
session := engine.NewSession() session := engine.NewSession()
defer session.Close() defer session.Close()
has, err := session.Get(bean) rows, err := session.Count(bean)
return !has, err return rows > 0, err
} }
// Is a table is exist // Is a table is exist
@ -587,41 +581,38 @@ func (engine *Engine) Sync(beans ...interface{}) error {
} }
} }
for idx, _ := range table.Indexes { for name, index := range table.Indexes {
session := engine.NewSession() session := engine.NewSession()
session.Statement.RefTable = table session.Statement.RefTable = table
defer session.Close() defer session.Close()
isExist, err := session.isIndexExist(table.Name, idx, false) if index.IsUnique {
if err != nil { isExist, err := session.isIndexExist(table.Name, name, true)
return err
}
if !isExist {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addIndex(table.Name, idx)
if err != nil { if err != nil {
return err return err
} }
} if !isExist {
} session := engine.NewSession()
session.Statement.RefTable = table
for uqe, _ := range table.Uniques { defer session.Close()
session := engine.NewSession() err = session.addUnique(table.Name, name)
session.Statement.RefTable = table if err != nil {
defer session.Close() return err
isExist, err := session.isIndexExist(table.Name, uqe, true) }
if err != nil { }
return err } else {
} isExist, err := session.isIndexExist(table.Name, name, false)
if !isExist {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addUnique(table.Name, uqe)
if err != nil { if err != nil {
return err return err
} }
if !isExist {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addIndex(table.Name, name)
if err != nil {
return err
}
}
} }
} }
} }

View File

@ -10,7 +10,7 @@ CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET
utf8 COLLATE utf8_general_ci; utf8 COLLATE utf8_general_ci;
*/ */
var showTestSql bool = false var showTestSql bool = true
func TestMyMysql(t *testing.T) { func TestMyMysql(t *testing.T) {
engine, err := NewEngine("mymysql", "xorm_test2/root/") engine, err := NewEngine("mymysql", "xorm_test2/root/")

View File

@ -23,6 +23,20 @@ func TestMysql(t *testing.T) {
testAll2(engine, t) testAll2(engine, t)
} }
func TestMysqlWithCache(t *testing.T) {
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close()
if err != nil {
t.Error(err)
return
}
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql
testAll(engine, t)
testAll2(engine, t)
}
func BenchmarkMysqlNoCache(t *testing.B) { func BenchmarkMysqlNoCache(t *testing.B) {
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8") engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close() defer engine.Close()

View File

@ -55,8 +55,12 @@ func TestPostgres2(t *testing.T) {
cascadeGet(engine, t) cascadeGet(engine, t)
fmt.Println("-------------- find --------------") fmt.Println("-------------- find --------------")
find(engine, t) find(engine, t)
fmt.Println("-------------- find2 --------------")
find2(engine, t)
fmt.Println("-------------- findMap --------------") fmt.Println("-------------- findMap --------------")
findMap(engine, t) findMap(engine, t)
fmt.Println("-------------- findMap2 --------------")
findMap2(engine, t)
fmt.Println("-------------- count --------------") fmt.Println("-------------- count --------------")
count(engine, t) count(engine, t)
fmt.Println("-------------- where --------------") fmt.Println("-------------- where --------------")

View File

@ -95,7 +95,8 @@ func (session *Session) Desc(colNames ...string) *Session {
if session.Statement.OrderStr != "" { if session.Statement.OrderStr != "" {
session.Statement.OrderStr += ", " session.Statement.OrderStr += ", "
} }
sql := strings.Join(colNames, session.Engine.Quote(" DESC, ")) newColNames := col2NewCols(colNames...)
sql := strings.Join(newColNames, session.Engine.Quote(" DESC, "))
session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC" session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC"
return session return session
} }
@ -104,7 +105,8 @@ func (session *Session) Asc(colNames ...string) *Session {
if session.Statement.OrderStr != "" { if session.Statement.OrderStr != "" {
session.Statement.OrderStr += ", " session.Statement.OrderStr += ", "
} }
sql := strings.Join(colNames, session.Engine.Quote(" ASC, ")) newColNames := col2NewCols(colNames...)
sql := strings.Join(newColNames, session.Engine.Quote(" ASC, "))
session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC" session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC"
return session return session
} }
@ -418,7 +420,8 @@ func (statement *Statement) convertIdSql(sql string) string {
if len(sqls) != 2 { if len(sqls) != 2 {
return "" return ""
} }
return fmt.Sprintf("SELECT %v FROM %v", statement.Engine.Quote(col.Name), sqls[1]) return fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()),
statement.Engine.Quote(col.Name), sqls[1])
} }
} }
return "" return ""
@ -552,6 +555,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
} }
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr)) sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
var idxes []int = make([]int, 0) var idxes []int = make([]int, 0)
var ides []interface{} = make([]interface{}, 0) var ides []interface{} = make([]interface{}, 0)
var temps []interface{} = make([]interface{}, len(ids)) var temps []interface{} = make([]interface{}, len(ids))
@ -571,7 +575,9 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
newSession := session.Engine.NewSession() newSession := session.Engine.NewSession()
defer newSession.Close() defer newSession.Close()
beans := reflect.New(sliceValue.Type()).Interface() slices := reflect.New(reflect.SliceOf(t))
beans := slices.Interface()
//beans := reflect.New(sliceValue.Type()).Interface()
err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans) err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans)
if err != nil { if err != nil {
return err return err
@ -589,7 +595,25 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
for j := 0; j < len(temps); j++ { for j := 0; j < len(temps); j++ {
bean := temps[j] bean := temps[j]
if bean != nil { if bean != nil {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean)))) if sliceValue.Kind() == reflect.Slice {
if t.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean)))
} else {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
}
} else if sliceValue.Kind() == reflect.Map {
var key int64
if table.PrimaryKey != "" {
key = ids[j]
} else {
key = int64(j)
}
if t.Kind() == reflect.Ptr {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean))
} else {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean)))
}
}
} else { } else {
session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j]) session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j])
cacher.DelBean(tableName, ids[j]) cacher.DelBean(tableName, ids[j])
@ -704,7 +728,18 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
sliceElementType := sliceValue.Type().Elem() sliceElementType := sliceValue.Type().Elem()
table := session.Engine.AutoMapType(sliceElementType) var table *Table
if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct {
table = session.Engine.AutoMapType(sliceElementType.Elem())
} else {
return errors.New("slice type")
}
} else if sliceElementType.Kind() == reflect.Struct {
table = session.Engine.AutoMapType(sliceElementType)
} else {
return errors.New("slice type")
}
session.Statement.RefTable = table session.Statement.RefTable = table
if len(condiBean) > 0 { if len(condiBean) > 0 {
@ -732,6 +767,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
if err != ErrCacheFailed { if err != ErrCacheFailed {
return err return err
} }
session.Engine.LogWarn("Cache Find Failed")
} }
resultsSlice, err := session.query(sql, args...) resultsSlice, err := session.query(sql, args...)
@ -740,13 +776,22 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} }
for i, results := range resultsSlice { for i, results := range resultsSlice {
newValue := reflect.New(sliceElementType) var newValue reflect.Value
if sliceElementType.Kind() == reflect.Ptr {
newValue = reflect.New(sliceElementType.Elem())
} else {
newValue = reflect.New(sliceElementType)
}
err := session.scanMapIntoStruct(newValue.Interface(), results) err := session.scanMapIntoStruct(newValue.Interface(), results)
if err != nil { if err != nil {
return err return err
} }
if sliceValue.Kind() == reflect.Slice { if sliceValue.Kind() == reflect.Slice {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface())))) if sliceElementType.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface())))
} else {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
}
} else if sliceValue.Kind() == reflect.Map { } else if sliceValue.Kind() == reflect.Map {
var key int64 var key int64
if table.PrimaryKey != "" { if table.PrimaryKey != "" {
@ -758,7 +803,11 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} else { } else {
key = int64(i) key = int64(i)
} }
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface()))) if sliceElementType.Kind() == reflect.Ptr {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface()))
} else {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface())))
}
} }
} }
return nil return nil
@ -852,7 +901,7 @@ func (session *Session) addIndex(tableName, idxName string) error {
defer session.Close() defer session.Close()
} }
//fmt.Println(idxName) //fmt.Println(idxName)
cols := session.Statement.RefTable.Indexes[idxName] cols := session.Statement.RefTable.Indexes[idxName].GenColsStr()
sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols) sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols)
_, err = session.exec(sql, args...) _, err = session.exec(sql, args...)
return err return err
@ -868,7 +917,7 @@ func (session *Session) addUnique(tableName, uqeName string) error {
defer session.Close() defer session.Close()
} }
//fmt.Println(uqeName, session.Statement.RefTable.Uniques) //fmt.Println(uqeName, session.Statement.RefTable.Uniques)
cols := session.Statement.RefTable.Uniques[uqeName] cols := session.Statement.RefTable.Indexes[uqeName].GenColsStr()
sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols) sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols)
_, err = session.exec(sql, args...) _, err = session.exec(sql, args...)
return err return err

View File

@ -198,18 +198,24 @@ func (statement *Statement) In(column string, args ...interface{}) {
} }
} }
func (statement *Statement) Cols(columns ...string) { func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0) newColumns := make([]string, 0)
for _, col := range columns { for _, col := range columns {
strings.Replace(col, "`", "", -1) strings.Replace(col, "`", "", -1)
strings.Replace(col, `"`, "", -1) strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",") ccols := strings.Split(col, ",")
for _, c := range ccols { for _, c := range ccols {
nc := strings.TrimSpace(c) newColumns = append(newColumns, strings.TrimSpace(c))
statement.columnMap[nc] = true
newColumns = append(newColumns, nc)
} }
} }
return newColumns
}
func (statement *Statement) Cols(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.columnMap[nc] = true
}
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
} }
@ -279,9 +285,9 @@ func (s *Statement) genIndexSQL() []string {
var sqls []string = make([]string, 0) var sqls []string = make([]string, 0)
tbName := s.TableName() tbName := s.TableName()
quote := s.Engine.Quote quote := s.Engine.Quote
for idxName, cols := range s.RefTable.Indexes { for idxName, index := range s.RefTable.Indexes {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)), sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(cols, quote(",")))) quote(tbName), quote(strings.Join(index.GenColsStr(), quote(","))))
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
return sqls return sqls
@ -291,11 +297,13 @@ func uniqueName(tableName, uqeName string) string {
return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
} }
func (statement *Statement) genUniqueSQL() []string { func (s *Statement) genUniqueSQL() []string {
var sqls []string = make([]string, 0) var sqls []string = make([]string, 0)
for indexName, cols := range statement.RefTable.Uniques { tbName := s.TableName()
sql := fmt.Sprintf("CREATE UNIQUE INDEX `%v` ON %v (%v);", uniqueName(statement.TableName(), indexName), quote := s.Engine.Quote
statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(strings.Join(cols, statement.Engine.Quote(",")))) for idxName, unique := range s.RefTable.Indexes {
sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)),
quote(tbName), quote(strings.Join(unique.GenColsStr(), quote(","))))
sqls = append(sqls, sql) sqls = append(sqls, sql)
} }
return sqls return sqls
@ -303,15 +311,14 @@ func (statement *Statement) genUniqueSQL() []string {
func (s *Statement) genDelIndexSQL() []string { func (s *Statement) genDelIndexSQL() []string {
var sqls []string = make([]string, 0) var sqls []string = make([]string, 0)
for indexName, _ := range s.RefTable.Uniques { for idxName, index := range s.RefTable.Indexes {
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(uniqueName(s.TableName(), indexName))) var rIdxName string
if s.Engine.Dialect.IndexOnTable() { if index.IsUnique {
sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName())) rIdxName = uniqueName(s.TableName(), idxName)
} else {
rIdxName = indexName(s.TableName(), idxName)
} }
sqls = append(sqls, sql) sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName))
}
for indexName, _ := range s.RefTable.Indexes {
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(uniqueName(s.TableName(), indexName)))
if s.Engine.Dialect.IndexOnTable() { if s.Engine.Dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName())) sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName()))
} }
@ -369,7 +376,11 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
colNames, args := buildConditions(statement.Engine, table, bean) colNames, args := buildConditions(statement.Engine, table, bean)
statement.ConditionStr = strings.Join(colNames, " and ") statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args statement.BeanArgs = args
return statement.genSelectSql(fmt.Sprintf("count(*) as %v", statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...) var id string = "*"
if table.PrimaryKey != "" {
id = statement.Engine.Quote(table.PrimaryKey)
}
return statement.genSelectSql(fmt.Sprintf("count(%v) as %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
} }
func (statement Statement) genSelectSql(columnStr string) (a string) { func (statement Statement) genSelectSql(columnStr string) (a string) {

View File

@ -149,24 +149,26 @@ const (
ONLYFROMDB ONLYFROMDB
) )
const (
NONEINDEX = iota
SINGLEINDEX
UNIONINDEX
)
const (
NONEUNIQUE = iota
SINGLEUNIQUE
UNIONUNIQUE
)
type Index struct { type Index struct {
Name string Name string
IsUnique bool IsUnique bool
Cols []*Column Cols []*Column
} }
func (index *Index) AddColumn(cols ...*Column) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
func (index *Index) GenColsStr() []string {
names := make([]string, len(index.Cols))
for idx, col := range index.Cols {
names[idx] = col.Name
}
return names
}
func NewIndex(name string, isUnique bool) *Index { func NewIndex(name string, isUnique bool) *Index {
return &Index{name, isUnique, make([]*Column, 0)} return &Index{name, isUnique, make([]*Column, 0)}
} }
@ -179,15 +181,13 @@ type Column struct {
Length2 int Length2 int
Nullable bool Nullable bool
Default string Default string
UniqueType int Index *Index
UniqueName string
IndexType int
IndexName string
IsPrimaryKey bool IsPrimaryKey bool
IsAutoIncrement bool IsAutoIncrement bool
MapType int MapType int
IsCreated bool IsCreated bool
IsUpdated bool IsUpdated bool
Comment string
} }
func (col *Column) String(engine *Engine) string { func (col *Column) String(engine *Engine) string {
@ -212,6 +212,10 @@ func (col *Column) String(engine *Engine) string {
if col.Default != "" { if col.Default != "" {
sql += "DEFAULT " + col.Default + " " sql += "DEFAULT " + col.Default + " "
} }
if col.Comment != "" {
sql += "COMMENT '" + col.Comment + "' "
}
return sql return sql
} }
@ -236,8 +240,7 @@ type Table struct {
Type reflect.Type Type reflect.Type
ColumnsSeq []string ColumnsSeq []string
Columns map[string]*Column Columns map[string]*Column
Indexes map[string][]string Indexes map[string]*Index
Uniques map[string][]string
PrimaryKey string PrimaryKey string
Created string Created string
Updated string Updated string
@ -251,6 +254,19 @@ func (table *Table) PKColumn() *Column {
func (table *Table) AddColumn(col *Column) { func (table *Table) AddColumn(col *Column) {
table.ColumnsSeq = append(table.ColumnsSeq, col.Name) table.ColumnsSeq = append(table.ColumnsSeq, col.Name)
table.Columns[col.Name] = col table.Columns[col.Name] = col
if col.IsPrimaryKey {
table.PrimaryKey = col.Name
}
if col.IsCreated {
table.Created = col.Name
}
if col.IsUpdated {
table.Updated = col.Name
}
}
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
} }
func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) { func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {