diff --git a/engine.go b/engine.go index 9aaa8a32..d85fe90e 100644 --- a/engine.go +++ b/engine.go @@ -96,23 +96,11 @@ func (engine *Engine) NoCache() *Session { } func (engine *Engine) MapCacher(bean interface{}, cacher Cacher) { - t := Type(bean) + t := rType(bean) engine.AutoMapType(t) engine.Tables[t].Cacher = cacher } -func Type(bean interface{}) reflect.Type { - sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - return reflect.TypeOf(sliceValue.Interface()) -} - -func StructName(v reflect.Type) string { - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - return v.Name() -} - func (e *Engine) OpenDB() (*sql.DB, error) { return sql.Open(e.DriverName, e.DataSourceName) } @@ -274,7 +262,7 @@ func (engine *Engine) AutoMapType(t reflect.Type) *Table { } func (engine *Engine) AutoMap(bean interface{}) *Table { - t := Type(bean) + t := rType(bean) return engine.AutoMapType(t) } @@ -465,7 +453,7 @@ func (engine *Engine) Map(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() for _, bean := range beans { - t := Type(bean) + t := rType(bean) engine.Tables[t] = engine.MapType(t) } return @@ -473,7 +461,7 @@ func (engine *Engine) Map(beans ...interface{}) (e error) { // is a table has func (engine *Engine) IsEmptyTable(bean interface{}) (bool, error) { - t := Type(bean) + t := rType(bean) if t.Kind() != reflect.Struct { return false, errors.New("bean should be a struct or struct's point") } @@ -485,7 +473,7 @@ func (engine *Engine) IsEmptyTable(bean interface{}) (bool, error) { } func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { - t := Type(bean) + t := rType(bean) if t.Kind() != reflect.Struct { return false, errors.New("bean should be a struct or struct's point") } @@ -497,7 +485,7 @@ func (engine *Engine) IsTableExist(bean interface{}) (bool, error) { } func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { - t := Type(bean) + t := rType(bean) if t.Kind() != reflect.Struct { return errors.New("error params") } @@ -511,7 +499,7 @@ func (engine *Engine) ClearCacheBean(bean interface{}, id int64) error { func (engine *Engine) ClearCache(beans ...interface{}) error { for _, bean := range beans { - t := Type(bean) + t := rType(bean) if t.Kind() != reflect.Struct { return errors.New("error params") } @@ -622,7 +610,7 @@ func (engine *Engine) UnMap(beans ...interface{}) (e error) { engine.mutex.Lock() defer engine.mutex.Unlock() for _, bean := range beans { - t := Type(bean) + t := rType(bean) if _, ok := engine.Tables[t]; ok { delete(engine.Tables, t) } diff --git a/helpers.go b/helpers.go index eb24e341..383c66b5 100644 --- a/helpers.go +++ b/helpers.go @@ -1,25 +1,46 @@ package xorm import ( + "reflect" "strings" ) -func IndexNoCase(s, sep string) int { +func indexNoCase(s, sep string) int { return strings.Index(strings.ToLower(s), strings.ToLower(sep)) } -func SplitNoCase(s, sep string) []string { - idx := IndexNoCase(s, sep) +func splitNoCase(s, sep string) []string { + idx := indexNoCase(s, sep) if idx < 0 { return []string{s} } return strings.Split(s, s[idx:idx+len(sep)]) } -func SplitNNoCase(s, sep string, n int) []string { - idx := IndexNoCase(s, sep) +func splitNNoCase(s, sep string, n int) []string { + idx := indexNoCase(s, sep) if idx < 0 { return []string{s} } return strings.SplitN(s, s[idx:idx+len(sep)], n) } + +func makeArray(elem string, count int) []string { + res := make([]string, count) + for i := 0; i < count; i++ { + res[i] = elem + } + return res +} + +func rType(bean interface{}) reflect.Type { + sliceValue := reflect.Indirect(reflect.ValueOf(bean)) + return reflect.TypeOf(sliceValue.Interface()) +} + +func structName(v reflect.Type) string { + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v.Name() +} diff --git a/mapper.go b/mapper.go index 8a781971..224658bb 100644 --- a/mapper.go +++ b/mapper.go @@ -44,7 +44,7 @@ func snakeCasedName(name string) string { return string(newstr) } -func Pascal2Sql(s string) (d string) { +func pascal2Sql(s string) (d string) { d = "" lastIdx := 0 for i := 0; i < len(s); i++ { diff --git a/session.go b/session.go index f44f0fc6..ac9c82ff 100644 --- a/session.go +++ b/session.go @@ -201,7 +201,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b return errors.New("Expected a pointer to a struct") } - table := session.Engine.Tables[Type(obj)] + table := session.Engine.Tables[rType(obj)] for key, data := range objMap { if _, ok := table.Columns[key]; !ok { @@ -709,7 +709,7 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.Statement.RefTable = table if len(condiBean) > 0 { - colNames, args := BuildConditions(session.Engine, table, condiBean[0]) + colNames, args := buildConditions(session.Engine, table, condiBean[0]) session.Statement.ConditionStr = strings.Join(colNames, " and ") session.Statement.BeanArgs = args } @@ -1038,7 +1038,7 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } bean := sliceValue.Index(0).Interface() - sliceElementType := Type(bean) + sliceElementType := rType(bean) table := session.Engine.AutoMapType(sliceElementType) session.Statement.RefTable = table @@ -1520,12 +1520,12 @@ func (session *Session) cacheUpdate(sql string, args ...interface{}) error { for _, id := range ids { if bean := cacher.GetBean(tableName, id); bean != nil { - sqls := SplitNNoCase(sql, "where", 2) + sqls := splitNNoCase(sql, "where", 2) if len(sqls) != 2 { return ErrCacheFailed } - sqls = SplitNNoCase(sqls[0], "set", 2) + sqls = splitNNoCase(sqls[0], "set", 2) if len(sqls) != 2 { return ErrCacheFailed } @@ -1567,7 +1567,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 defer session.Close() } - t := Type(bean) + t := rType(bean) var colNames []string var args []interface{} @@ -1578,7 +1578,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 session.Statement.RefTable = table if session.Statement.ColumnStr == "" { - colNames, args = BuildConditions(session.Engine, table, bean) + colNames, args = buildConditions(session.Engine, table, bean) } else { colNames, args, err = table.GenCols(session, bean, true, true) if err != nil { @@ -1614,7 +1614,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var condiArgs []interface{} if len(condiBean) > 0 { - condiColNames, condiArgs = BuildConditions(session.Engine, session.Statement.RefTable, condiBean[0]) + condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0]) } var condition = "" @@ -1714,7 +1714,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.AutoMap(bean) session.Statement.RefTable = table - colNames, args := BuildConditions(session.Engine, table, bean) + colNames, args := buildConditions(session.Engine, table, bean) var condition = "" if session.Statement.WhereStr != "" { diff --git a/statement.go b/statement.go index e21467ca..7d68069f 100644 --- a/statement.go +++ b/statement.go @@ -35,14 +35,6 @@ type Statement struct { UseAutoTime bool } -func MakeArray(elem string, count int) []string { - res := make([]string, count) - for i := 0; i < count; i++ { - res[i] = elem - } - return res -} - func (statement *Statement) Init() { statement.RefTable = nil statement.Start = 0 @@ -76,7 +68,7 @@ func (statement *Statement) Where(querystring string, args ...interface{}) { } func (statement *Statement) Table(tableNameOrBean interface{}) { - t := Type(tableNameOrBean) + t := rType(tableNameOrBean) if t.Kind() == reflect.String { statement.AltTableName = tableNameOrBean.(string) } else if t.Kind() == reflect.Struct { @@ -84,7 +76,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) { } } -func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) { +func buildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns { @@ -196,7 +188,7 @@ func (statement *Statement) Id(id int64) { } func (statement *Statement) In(column string, args ...interface{}) { - inStr := fmt.Sprintf("%v IN (%v)", column, strings.Join(MakeArray("?", len(args)), ",")) + inStr := fmt.Sprintf("%v IN (%v)", column, strings.Join(makeArray("?", len(args)), ",")) if statement.WhereStr == "" { statement.WhereStr = inStr statement.Params = args @@ -337,7 +329,7 @@ func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { table := statement.Engine.AutoMap(bean) statement.RefTable = table - colNames, args := BuildConditions(statement.Engine, table, bean) + colNames, args := buildConditions(statement.Engine, table, bean) statement.ConditionStr = strings.Join(colNames, " and ") statement.BeanArgs = args @@ -374,7 +366,7 @@ func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) table := statement.Engine.AutoMap(bean) statement.RefTable = table - colNames, args := BuildConditions(statement.Engine, table, bean) + colNames, args := buildConditions(statement.Engine, table, bean) statement.ConditionStr = strings.Join(colNames, " and ") statement.BeanArgs = args return statement.genSelectSql(fmt.Sprintf("count(*) as %v", statement.Engine.Quote("total"))), append(statement.Params, statement.BeanArgs...)