Added support for Find(*[]*Struct); added notnull;

This commit is contained in:
Lunny Xiao 2013-10-05 00:44:43 +08:00
parent 4b2425d3d3
commit c0d008e631
9 changed files with 249 additions and 131 deletions

View File

@ -271,7 +271,22 @@ func find(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
fmt.Println(users)
for _, user := range users {
fmt.Println(user)
}
}
func find2(engine *Engine, t *testing.T) {
users := make([]*Userinfo, 0)
err := engine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
for _, user := range users {
fmt.Println(user)
}
}
func findMap(engine *Engine, t *testing.T) {
@ -282,7 +297,22 @@ func findMap(engine *Engine, t *testing.T) {
t.Error(err)
panic(err)
}
fmt.Println(users)
for _, user := range users {
fmt.Println(user)
}
}
func findMap2(engine *Engine, t *testing.T) {
users := make(map[int64]*Userinfo)
err := engine.Find(&users)
if err != nil {
t.Error(err)
panic(err)
}
for id, user := range users {
fmt.Println(id, user)
}
}
func count(engine *Engine, t *testing.T) {
@ -683,7 +713,7 @@ func testCols(engine *Engine, t *testing.T) {
fmt.Println(users)
tmpUsers := []tempUser{}
err = engine.Table("userinfo").Cols("id, username").Find(&tmpUsers)
err = engine.NoCache().Table("userinfo").Cols("id, username").Find(&tmpUsers)
if err != nil {
t.Error(err)
panic(err)
@ -1055,8 +1085,12 @@ func testAll(engine *Engine, t *testing.T) {
cascadeGet(engine, t)
fmt.Println("-------------- find --------------")
find(engine, t)
fmt.Println("-------------- find2 --------------")
find2(engine, t)
fmt.Println("-------------- findMap --------------")
findMap(engine, t)
fmt.Println("-------------- findMap2 --------------")
findMap2(engine, t)
fmt.Println("-------------- count --------------")
count(engine, t)
fmt.Println("-------------- where --------------")

View File

@ -46,5 +46,4 @@ func doBenchCacheFind(engine *Engine, b *testing.B) {
b.Error(err)
return
}
}

117
engine.go
View File

@ -44,6 +44,7 @@ type Engine struct {
ShowSQL bool
ShowErr bool
ShowDebug bool
ShowWarn bool
Pool IConnectPool
Filters []Filter
Logger io.Writer
@ -159,6 +160,12 @@ func (engine *Engine) LogDebug(contents ...interface{}) {
}
}
func (engine *Engine) LogWarn(contents ...interface{}) {
if engine.ShowWarn {
io.WriteString(engine.Logger, fmt.Sprintln(contents...))
}
}
func (engine *Engine) Sql(querystring string, args ...interface{}) *Session {
session := engine.NewSession()
session.IsAutoClose = true
@ -286,7 +293,8 @@ func (engine *Engine) AutoMap(bean interface{}) *Table {
}
func (engine *Engine) newTable() *Table {
table := &Table{Indexes: map[string][]string{}, Uniques: map[string][]string{}}
table := &Table{}
table.Indexes = make(map[string]*Index)
table.Columns = make(map[string]*Column)
table.ColumnsSeq = make([]string, 0)
table.Cacher = engine.Cacher
@ -347,22 +355,42 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
col.IsCreated = true
case k == "UPDATED":
col.IsUpdated = true
case strings.HasPrefix(k, "INDEX"):
if k == "INDEX" {
col.IndexName = ""
col.IndexType = SINGLEINDEX
/*case strings.HasPrefix(k, "--"):
col.Comment = k[2:len(k)]*/
case strings.HasPrefix(k, "INDEX(") && strings.HasSuffix(k, ")"):
indexName := k[len("INDEX")+1 : len(k)-1]
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col)
col.Index = index
} else {
col.IndexName = k[len("INDEX")+1 : len(k)-1]
col.IndexType = UNIONINDEX
index := NewIndex(indexName, false)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
}
case strings.HasPrefix(k, "UNIQUE"):
if k == "UNIQUE" {
col.UniqueName = ""
col.UniqueType = SINGLEUNIQUE
case k == "INDEX":
index := NewIndex(col.Name, false)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
case strings.HasPrefix(k, "UNIQUE(") && strings.HasSuffix(k, ")"):
indexName := k[len("UNIQUE")+1 : len(k)-1]
if index, ok := table.Indexes[indexName]; ok {
index.AddColumn(col)
col.Index = index
} else {
col.UniqueName = k[len("UNIQUE")+1 : len(k)-1]
col.UniqueType = UNIONUNIQUE
index := NewIndex(indexName, true)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
}
case k == "UNIQUE":
index := NewIndex(col.Name, true)
index.AddColumn(col)
table.AddIndex(index)
col.Index = index
case k == "NOTNULL":
col.Nullable = false
case k == "NOT":
default:
if strings.HasPrefix(k, "'") && strings.HasSuffix(k, "'") {
@ -395,60 +423,26 @@ func (engine *Engine) MapType(t reflect.Type) *Table {
if col.SQLType.Name == "" {
col.SQLType = Type2SQLType(fieldType)
}
if col.Length == 0 {
col.Length = col.SQLType.DefaultLength
}
if col.Length2 == 0 {
col.Length2 = col.SQLType.DefaultLength2
}
if col.Name == "" {
col.Name = engine.Mapper.Obj2Table(t.Field(i).Name)
}
if col.IsPrimaryKey {
table.PrimaryKey = col.Name
}
if col.IsCreated {
table.Created = col.Name
}
if col.IsUpdated {
table.Updated = col.Name
}
if col.IndexType == SINGLEINDEX {
col.IndexName = col.Name
table.Indexes[col.IndexName] = []string{col.Name}
} else if col.IndexType == UNIONINDEX {
if unionIdxes, ok := table.Indexes[col.IndexName]; ok {
table.Indexes[col.IndexName] = append(unionIdxes, col.Name)
} else {
table.Indexes[col.IndexName] = []string{col.Name}
}
}
if col.UniqueType == SINGLEUNIQUE {
col.UniqueName = col.Name
table.Uniques[col.UniqueName] = []string{col.Name}
} else if col.UniqueType == UNIONUNIQUE {
if unionUniques, ok := table.Uniques[col.UniqueName]; ok {
table.Uniques[col.UniqueName] = append(unionUniques, col.Name)
} else {
table.Uniques[col.UniqueName] = []string{col.Name}
}
}
}
} else {
sqlType := Type2SQLType(fieldType)
col = &Column{engine.Mapper.Obj2Table(t.Field(i).Name), t.Field(i).Name, sqlType,
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", NONEUNIQUE, "",
NONEINDEX, "", false, false, TWOSIDES, false, false}
sqlType.DefaultLength, sqlType.DefaultLength2, true, "", nil, false, false,
TWOSIDES, false, false, ""}
}
if col.IsAutoIncrement {
col.Nullable = false
}
if col.IsPrimaryKey {
table.PrimaryKey = col.Name
}
table.AddColumn(col)
if col.FieldName == "Id" || strings.HasSuffix(col.FieldName, ".Id") {
@ -487,8 +481,8 @@ func (engine *Engine) IsTableEmpty(bean interface{}) (bool, error) {
engine.AutoMapType(t)
session := engine.NewSession()
defer session.Close()
has, err := session.Get(bean)
return !has, err
rows, err := session.Count(bean)
return rows > 0, err
}
// Is a table is exist
@ -587,11 +581,12 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
}
for idx, _ := range table.Indexes {
for name, index := range table.Indexes {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
isExist, err := session.isIndexExist(table.Name, idx, false)
if index.IsUnique {
isExist, err := session.isIndexExist(table.Name, name, true)
if err != nil {
return err
}
@ -599,18 +594,13 @@ func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addIndex(table.Name, idx)
err = session.addUnique(table.Name, name)
if err != nil {
return err
}
}
}
for uqe, _ := range table.Uniques {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
isExist, err := session.isIndexExist(table.Name, uqe, true)
} else {
isExist, err := session.isIndexExist(table.Name, name, false)
if err != nil {
return err
}
@ -618,7 +608,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
session := engine.NewSession()
session.Statement.RefTable = table
defer session.Close()
err = session.addUnique(table.Name, uqe)
err = session.addIndex(table.Name, name)
if err != nil {
return err
}
@ -626,6 +616,7 @@ func (engine *Engine) Sync(beans ...interface{}) error {
}
}
}
}
return nil
}

View File

@ -10,7 +10,7 @@ CREATE DATABASE IF NOT EXISTS xorm_test CHARACTER SET
utf8 COLLATE utf8_general_ci;
*/
var showTestSql bool = false
var showTestSql bool = true
func TestMyMysql(t *testing.T) {
engine, err := NewEngine("mymysql", "xorm_test2/root/")

View File

@ -23,6 +23,20 @@ func TestMysql(t *testing.T) {
testAll2(engine, t)
}
func TestMysqlWithCache(t *testing.T) {
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close()
if err != nil {
t.Error(err)
return
}
engine.SetDefaultCacher(NewLRUCacher(NewMemoryStore(), 1000))
engine.ShowSQL = showTestSql
testAll(engine, t)
testAll2(engine, t)
}
func BenchmarkMysqlNoCache(t *testing.B) {
engine, err := NewEngine("mysql", "root:@/xorm_test?charset=utf8")
defer engine.Close()

View File

@ -55,8 +55,12 @@ func TestPostgres2(t *testing.T) {
cascadeGet(engine, t)
fmt.Println("-------------- find --------------")
find(engine, t)
fmt.Println("-------------- find2 --------------")
find2(engine, t)
fmt.Println("-------------- findMap --------------")
findMap(engine, t)
fmt.Println("-------------- findMap2 --------------")
findMap2(engine, t)
fmt.Println("-------------- count --------------")
count(engine, t)
fmt.Println("-------------- where --------------")

View File

@ -95,7 +95,8 @@ func (session *Session) Desc(colNames ...string) *Session {
if session.Statement.OrderStr != "" {
session.Statement.OrderStr += ", "
}
sql := strings.Join(colNames, session.Engine.Quote(" DESC, "))
newColNames := col2NewCols(colNames...)
sql := strings.Join(newColNames, session.Engine.Quote(" DESC, "))
session.Statement.OrderStr += session.Engine.Quote(sql) + " DESC"
return session
}
@ -104,7 +105,8 @@ func (session *Session) Asc(colNames ...string) *Session {
if session.Statement.OrderStr != "" {
session.Statement.OrderStr += ", "
}
sql := strings.Join(colNames, session.Engine.Quote(" ASC, "))
newColNames := col2NewCols(colNames...)
sql := strings.Join(newColNames, session.Engine.Quote(" ASC, "))
session.Statement.OrderStr += session.Engine.Quote(sql) + " ASC"
return session
}
@ -418,7 +420,8 @@ func (statement *Statement) convertIdSql(sql string) string {
if len(sqls) != 2 {
return ""
}
return fmt.Sprintf("SELECT %v FROM %v", statement.Engine.Quote(col.Name), sqls[1])
return fmt.Sprintf("SELECT %v.%v FROM %v", statement.Engine.Quote(statement.TableName()),
statement.Engine.Quote(col.Name), sqls[1])
}
}
return ""
@ -552,6 +555,7 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
var idxes []int = make([]int, 0)
var ides []interface{} = make([]interface{}, 0)
var temps []interface{} = make([]interface{}, len(ids))
@ -571,7 +575,9 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
newSession := session.Engine.NewSession()
defer newSession.Close()
beans := reflect.New(sliceValue.Type()).Interface()
slices := reflect.New(reflect.SliceOf(t))
beans := slices.Interface()
//beans := reflect.New(sliceValue.Type()).Interface()
err = newSession.In("(id)", ides...).OrderBy(session.Statement.OrderStr).NoCache().Find(beans)
if err != nil {
return err
@ -589,7 +595,25 @@ func (session *Session) cacheFind(t reflect.Type, sql string, rowsSlicePtr inter
for j := 0; j < len(temps); j++ {
bean := temps[j]
if bean != nil {
if sliceValue.Kind() == reflect.Slice {
if t.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(bean)))
} else {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(bean))))
}
} else if sliceValue.Kind() == reflect.Map {
var key int64
if table.PrimaryKey != "" {
key = ids[j]
} else {
key = int64(j)
}
if t.Kind() == reflect.Ptr {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(bean))
} else {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(bean)))
}
}
} else {
session.Engine.LogDebug("[xorm:cacheFind] cache delete:", tableName, ides[j])
cacher.DelBean(tableName, ids[j])
@ -704,7 +728,18 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
sliceElementType := sliceValue.Type().Elem()
table := session.Engine.AutoMapType(sliceElementType)
var table *Table
if sliceElementType.Kind() == reflect.Ptr {
if sliceElementType.Elem().Kind() == reflect.Struct {
table = session.Engine.AutoMapType(sliceElementType.Elem())
} else {
return errors.New("slice type")
}
} else if sliceElementType.Kind() == reflect.Struct {
table = session.Engine.AutoMapType(sliceElementType)
} else {
return errors.New("slice type")
}
session.Statement.RefTable = table
if len(condiBean) > 0 {
@ -732,6 +767,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
if err != ErrCacheFailed {
return err
}
session.Engine.LogWarn("Cache Find Failed")
}
resultsSlice, err := session.query(sql, args...)
@ -740,13 +776,22 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
}
for i, results := range resultsSlice {
newValue := reflect.New(sliceElementType)
var newValue reflect.Value
if sliceElementType.Kind() == reflect.Ptr {
newValue = reflect.New(sliceElementType.Elem())
} else {
newValue = reflect.New(sliceElementType)
}
err := session.scanMapIntoStruct(newValue.Interface(), results)
if err != nil {
return err
}
if sliceValue.Kind() == reflect.Slice {
if sliceElementType.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(newValue.Interface())))
} else {
sliceValue.Set(reflect.Append(sliceValue, reflect.Indirect(reflect.ValueOf(newValue.Interface()))))
}
} else if sliceValue.Kind() == reflect.Map {
var key int64
if table.PrimaryKey != "" {
@ -758,9 +803,13 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{})
} else {
key = int64(i)
}
if sliceElementType.Kind() == reflect.Ptr {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(newValue.Interface()))
} else {
sliceValue.SetMapIndex(reflect.ValueOf(key), reflect.Indirect(reflect.ValueOf(newValue.Interface())))
}
}
}
return nil
}
@ -852,7 +901,7 @@ func (session *Session) addIndex(tableName, idxName string) error {
defer session.Close()
}
//fmt.Println(idxName)
cols := session.Statement.RefTable.Indexes[idxName]
cols := session.Statement.RefTable.Indexes[idxName].GenColsStr()
sql, args := session.Statement.genAddIndexStr(indexName(tableName, idxName), cols)
_, err = session.exec(sql, args...)
return err
@ -868,7 +917,7 @@ func (session *Session) addUnique(tableName, uqeName string) error {
defer session.Close()
}
//fmt.Println(uqeName, session.Statement.RefTable.Uniques)
cols := session.Statement.RefTable.Uniques[uqeName]
cols := session.Statement.RefTable.Indexes[uqeName].GenColsStr()
sql, args := session.Statement.genAddUniqueStr(uniqueName(tableName, uqeName), cols)
_, err = session.exec(sql, args...)
return err

View File

@ -198,18 +198,24 @@ func (statement *Statement) In(column string, args ...interface{}) {
}
}
func (statement *Statement) Cols(columns ...string) {
func col2NewCols(columns ...string) []string {
newColumns := make([]string, 0)
for _, col := range columns {
strings.Replace(col, "`", "", -1)
strings.Replace(col, `"`, "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
nc := strings.TrimSpace(c)
statement.columnMap[nc] = true
newColumns = append(newColumns, nc)
newColumns = append(newColumns, strings.TrimSpace(c))
}
}
return newColumns
}
func (statement *Statement) Cols(columns ...string) {
newColumns := col2NewCols(columns...)
for _, nc := range newColumns {
statement.columnMap[nc] = true
}
statement.ColumnStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
}
@ -279,9 +285,9 @@ func (s *Statement) genIndexSQL() []string {
var sqls []string = make([]string, 0)
tbName := s.TableName()
quote := s.Engine.Quote
for idxName, cols := range s.RefTable.Indexes {
for idxName, index := range s.RefTable.Indexes {
sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(tbName, idxName)),
quote(tbName), quote(strings.Join(cols, quote(","))))
quote(tbName), quote(strings.Join(index.GenColsStr(), quote(","))))
sqls = append(sqls, sql)
}
return sqls
@ -291,11 +297,13 @@ func uniqueName(tableName, uqeName string) string {
return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
}
func (statement *Statement) genUniqueSQL() []string {
func (s *Statement) genUniqueSQL() []string {
var sqls []string = make([]string, 0)
for indexName, cols := range statement.RefTable.Uniques {
sql := fmt.Sprintf("CREATE UNIQUE INDEX `%v` ON %v (%v);", uniqueName(statement.TableName(), indexName),
statement.Engine.Quote(statement.TableName()), statement.Engine.Quote(strings.Join(cols, statement.Engine.Quote(","))))
tbName := s.TableName()
quote := s.Engine.Quote
for idxName, unique := range s.RefTable.Indexes {
sql := fmt.Sprintf("CREATE UNIQUE INDEX %v ON %v (%v);", quote(uniqueName(tbName, idxName)),
quote(tbName), quote(strings.Join(unique.GenColsStr(), quote(","))))
sqls = append(sqls, sql)
}
return sqls
@ -303,15 +311,14 @@ func (statement *Statement) genUniqueSQL() []string {
func (s *Statement) genDelIndexSQL() []string {
var sqls []string = make([]string, 0)
for indexName, _ := range s.RefTable.Uniques {
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(uniqueName(s.TableName(), indexName)))
if s.Engine.Dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName()))
for idxName, index := range s.RefTable.Indexes {
var rIdxName string
if index.IsUnique {
rIdxName = uniqueName(s.TableName(), idxName)
} else {
rIdxName = indexName(s.TableName(), idxName)
}
sqls = append(sqls, sql)
}
for indexName, _ := range s.RefTable.Indexes {
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(uniqueName(s.TableName(), indexName)))
sql := fmt.Sprintf("DROP INDEX %v", s.Engine.Quote(rIdxName))
if s.Engine.Dialect.IndexOnTable() {
sql += fmt.Sprintf(" ON %v", s.Engine.Quote(s.TableName()))
}
@ -369,7 +376,11 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{})
colNames, args := buildConditions(statement.Engine, table, bean)
statement.ConditionStr = strings.Join(colNames, " and ")
statement.BeanArgs = args
return statement.genSelectSql(fmt.Sprintf("count(*) as %v", statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
var id string = "*"
if table.PrimaryKey != "" {
id = statement.Engine.Quote(table.PrimaryKey)
}
return statement.genSelectSql(fmt.Sprintf("count(%v) as %v", id, statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)
}
func (statement Statement) genSelectSql(columnStr string) (a string) {

View File

@ -149,24 +149,26 @@ const (
ONLYFROMDB
)
const (
NONEINDEX = iota
SINGLEINDEX
UNIONINDEX
)
const (
NONEUNIQUE = iota
SINGLEUNIQUE
UNIONUNIQUE
)
type Index struct {
Name string
IsUnique bool
Cols []*Column
}
func (index *Index) AddColumn(cols ...*Column) {
for _, col := range cols {
index.Cols = append(index.Cols, col)
}
}
func (index *Index) GenColsStr() []string {
names := make([]string, len(index.Cols))
for idx, col := range index.Cols {
names[idx] = col.Name
}
return names
}
func NewIndex(name string, isUnique bool) *Index {
return &Index{name, isUnique, make([]*Column, 0)}
}
@ -179,15 +181,13 @@ type Column struct {
Length2 int
Nullable bool
Default string
UniqueType int
UniqueName string
IndexType int
IndexName string
Index *Index
IsPrimaryKey bool
IsAutoIncrement bool
MapType int
IsCreated bool
IsUpdated bool
Comment string
}
func (col *Column) String(engine *Engine) string {
@ -212,6 +212,10 @@ func (col *Column) String(engine *Engine) string {
if col.Default != "" {
sql += "DEFAULT " + col.Default + " "
}
if col.Comment != "" {
sql += "COMMENT '" + col.Comment + "' "
}
return sql
}
@ -236,8 +240,7 @@ type Table struct {
Type reflect.Type
ColumnsSeq []string
Columns map[string]*Column
Indexes map[string][]string
Uniques map[string][]string
Indexes map[string]*Index
PrimaryKey string
Created string
Updated string
@ -251,6 +254,19 @@ func (table *Table) PKColumn() *Column {
func (table *Table) AddColumn(col *Column) {
table.ColumnsSeq = append(table.ColumnsSeq, col.Name)
table.Columns[col.Name] = col
if col.IsPrimaryKey {
table.PrimaryKey = col.Name
}
if col.IsCreated {
table.Created = col.Name
}
if col.IsUpdated {
table.Updated = col.Name
}
}
func (table *Table) AddIndex(index *Index) {
table.Indexes[index.Name] = index
}
func (table *Table) genCols(session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {