diff --git a/engine.go b/engine.go index 146b2ee0..69bb1f11 100644 --- a/engine.go +++ b/engine.go @@ -18,6 +18,7 @@ const ( type Engine struct { Mapper IMapper + TagIdentifier string DriverName string DataSourceName string Tables map[reflect.Type]Table @@ -60,7 +61,7 @@ func (engine *Engine) Where(querystring string, args ...interface{}) *Engine { return engine } -func (engine *Engine) Id(id int) *Engine { +func (engine *Engine) Id(id int64) *Engine { engine.Statement.Id(id) return engine } @@ -70,6 +71,11 @@ func (engine *Engine) In(column string, args ...interface{}) *Engine { return engine } +func (engine *Engine) Table(tableName string) *Engine { + engine.Statement.Table(tableName) + return engine +} + func (engine *Engine) Limit(limit int, start ...int) *Engine { engine.Statement.Limit(limit, start...) return engine @@ -96,67 +102,6 @@ func (engine *Engine) Having(conditions string) *Engine { return engine } -func (e *Engine) genColumnStr(col *Column) string { - sql := "`" + col.Name + "` " - if col.SQLType == Date { - sql += " datetime " - } else { - if e.DriverName == SQLITE && col.IsPrimaryKey { - sql += "integer" - } else { - sql += col.SQLType.Name - } - if e.DriverName != SQLITE && col.SQLType != Text { - if col.SQLType != Decimal { - sql += "(" + strconv.Itoa(col.Length) + ")" - } else { - sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")" - } - } - } - - if col.Nullable { - sql += " NULL " - } else { - sql += " NOT NULL " - } - //fmt.Println(key) - if col.IsPrimaryKey { - sql += "PRIMARY KEY " - } - if col.IsAutoIncrement { - sql += e.AutoIncrement + " " - } - if col.IsUnique { - sql += "Unique " - } - return sql -} - -func (e *Engine) genCreateSQL(table *Table) string { - sql := "CREATE TABLE IF NOT EXISTS `" + table.Name + "` (" - //fmt.Println(session.Mapper.Obj2Table(session.PrimaryKey)) - for _, col := range table.Columns { - sql += e.genColumnStr(&col) - sql += "," - } - sql = sql[:len(sql)-2] + ");" - return sql -} - -func (e *Engine) genDropSQL(table *Table) string { - sql := "DROP TABLE IF EXISTS `" + table.Name + "`;" - return sql -} - -/* -map an object into a table object -*/ -func (engine *Engine) MapOne(bean interface{}) Table { - t := Type(bean) - return engine.MapType(t) -} - func (engine *Engine) AutoMapType(t reflect.Type) *Table { table, ok := engine.Tables[t] if !ok { @@ -179,7 +124,7 @@ func (engine *Engine) MapType(t reflect.Type) Table { for i := 0; i < t.NumField(); i++ { tag := t.Field(i).Tag - ormTagStr := tag.Get("xorm") + ormTagStr := tag.Get(engine.TagIdentifier) var col Column fieldType := t.Field(i).Type @@ -278,7 +223,7 @@ func (engine *Engine) Map(beans ...interface{}) (e error) { for _, bean := range beans { t := Type(bean) if _, ok := engine.Tables[t]; !ok { - engine.Tables[t] = engine.MapOne(bean) + engine.Tables[t] = engine.MapType(t) } } return @@ -294,11 +239,6 @@ func (engine *Engine) UnMap(beans ...interface{}) (e error) { return } -func (engine *Engine) Bean2Table(bean interface{}) *Table { - table := engine.Tables[Type(bean)] - return &table -} - func (e *Engine) DropAll() error { session, err := e.MakeSession() session.Begin() @@ -308,7 +248,8 @@ func (e *Engine) DropAll() error { } for _, table := range e.Tables { - sql := e.genDropSQL(&table) + e.Statement.RefTable = &table + sql := e.Statement.genDropSQL() _, err = session.Exec(sql) if err != nil { session.Rollback() @@ -321,15 +262,13 @@ func (e *Engine) DropAll() error { func (e *Engine) CreateTables(beans ...interface{}) error { session, err := e.MakeSession() session.Begin() + session.Statement = e.Statement defer session.Close() if err != nil { return err } for _, bean := range beans { - table := e.MapOne(bean) - e.Tables[table.Type] = table - sql := e.genCreateSQL(&table) - _, err = session.Exec(sql) + err = session.CreateTable(bean) if err != nil { session.Rollback() return err @@ -347,7 +286,8 @@ func (e *Engine) CreateAll() error { } for _, table := range e.Tables { - sql := e.genCreateSQL(&table) + e.Statement.RefTable = &table + sql := e.Statement.genCreateSQL() _, err = session.Exec(sql) if err != nil { session.Rollback() diff --git a/session.go b/session.go index 70ae825d..9a82b1aa 100644 --- a/session.go +++ b/session.go @@ -20,7 +20,7 @@ type Session struct { } func (session *Session) Init() { - session.Statement = Statement{} + session.Statement = Statement{Engine: session.Engine} session.IsAutoCommit = true session.IsCommitedOrRollbacked = false } @@ -34,11 +34,16 @@ func (session *Session) Where(querystring string, args ...interface{}) *Session return session } -func (session *Session) Id(id int) *Session { +func (session *Session) Id(id int64) *Session { session.Statement.Id(id) return session } +func (session *Session) Table(tableName string) *Session { + session.Statement.Table(tableName) + return session +} + func (session *Session) In(column string, args ...interface{}) *Session { session.Statement.In(column, args...) return session @@ -112,7 +117,7 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b return errors.New("expected a pointer to a struct") } - table := session.Engine.Bean2Table(obj) + table := session.Engine.Tables[Type(obj)] for key, data := range objMap { structField := dataStruct.FieldByName(table.Columns[key].FieldName) @@ -195,8 +200,8 @@ func (session *Session) innerExec(sql string, args ...interface{}) (sql.Result, } func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error) { - if session.Statement.Table != nil && session.Statement.Table.PrimaryKey != "" { - sql = strings.Replace(sql, "(id)", session.Statement.Table.PrimaryKey, -1) + if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { + sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) } if session.Engine.ShowSQL { fmt.Println(sql) @@ -207,20 +212,22 @@ func (session *Session) Exec(sql string, args ...interface{}) (sql.Result, error return session.Tx.Exec(sql, args...) } +func (session *Session) CreateTable(bean interface{}) error { + statement := session.Statement + defer statement.Init() + statement.RefTable = session.Engine.AutoMap(bean) + sql := statement.genCreateSQL() + _, err := session.Exec(sql) + return err +} + func (session *Session) Get(bean interface{}) error { statement := session.Statement - defer session.Statement.Init() + defer statement.Init() statement.Limit(1) - table := session.Engine.AutoMap(bean) - statement.Table = table - - colNames, args := session.BuildConditions(table, bean) - statement.ColumnStr = strings.Join(colNames, " and ") - statement.BeanArgs = args - - sql := statement.generateSql() - resultsSlice, err := session.Query(sql, append(statement.Params, statement.BeanArgs...)...) + sql, args := statement.genGetSql(bean) + resultsSlice, err := session.Query(sql, args...) if err != nil { return err @@ -242,14 +249,9 @@ func (session *Session) Get(bean interface{}) error { func (session *Session) Count(bean interface{}) (int64, error) { statement := session.Statement defer session.Statement.Init() - table := session.Engine.AutoMap(bean) - statement.Table = table + sql, args := statement.genCountSql(bean) - colNames, args := session.BuildConditions(table, bean) - statement.ColumnStr = strings.Join(colNames, " and ") - statement.BeanArgs = args - - resultsSlice, err := session.Query(statement.genCountSql(), append(statement.Params, statement.BeanArgs...)...) + resultsSlice, err := session.Query(sql, args...) if err != nil { return 0, err } @@ -273,10 +275,10 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) sliceElementType := sliceValue.Type().Elem() table := session.Engine.AutoMapType(sliceElementType) - statement.Table = table + statement.RefTable = table if len(condiBean) > 0 { - colNames, args := session.BuildConditions(table, condiBean[0]) + colNames, args := BuildConditions(session.Engine, table, condiBean[0]) statement.ColumnStr = strings.Join(colNames, " and ") statement.BeanArgs = args } @@ -300,8 +302,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } func (session *Session) Query(sql string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { - if session.Statement.Table != nil && session.Statement.Table.PrimaryKey != "" { - sql = strings.Replace(sql, "(id)", session.Statement.Table.PrimaryKey, -1) + if session.Statement.RefTable != nil && session.Statement.RefTable.PrimaryKey != "" { + sql = strings.Replace(sql, "(id)", session.Statement.RefTable.PrimaryKey, -1) } if session.Engine.ShowSQL { fmt.Println(sql) @@ -429,7 +431,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { sliceElementType := Type(bean) table := session.Engine.AutoMapType(sliceElementType) - session.Statement.Table = table + session.Statement.RefTable = table size := sliceValue.Len() @@ -470,7 +472,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", session.Engine.QuoteIdentifier, - table.Name, + session.Statement.TableName(), session.Engine.QuoteIdentifier, strings.Join(colNames, ", "), strings.Join(colMultiPlaces, "),(")) @@ -491,7 +493,7 @@ func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) { func (session *Session) InsertOne(bean interface{}) (int64, error) { table := session.Engine.AutoMap(bean) - session.Statement.Table = table + session.Statement.RefTable = table colNames := make([]string, 0) colPlaces := make([]string, 0) var args = make([]interface{}, 0) @@ -506,14 +508,14 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { colPlaces = append(colPlaces, "?") } - statement := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", + sql := fmt.Sprintf("INSERT INTO %v%v%v (%v) VALUES (%v)", session.Engine.QuoteIdentifier, - table.Name, + session.Statement.TableName(), session.Engine.QuoteIdentifier, strings.Join(colNames, ", "), strings.Join(colPlaces, ", ")) - res, err := session.Exec(statement, args...) + res, err := session.Exec(sql, args...) if err != nil { return -1, err } @@ -527,48 +529,15 @@ func (session *Session) InsertOne(bean interface{}) (int64, error) { return id, nil } -func (session *Session) BuildConditions(table *Table, bean interface{}) ([]string, []interface{}) { - colNames := make([]string, 0) - var args = make([]interface{}, 0) - for _, col := range table.Columns { - fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) - fieldType := reflect.TypeOf(fieldValue.Interface()) - val := fieldValue.Interface() - switch fieldType.Kind() { - case reflect.String: - if fieldValue.String() == "" { - continue - } - case reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) - if t.IsZero() { - continue - } - } - default: - continue - } - args = append(args, val) - colNames = append(colNames, session.Engine.QuoteIdentifier+col.Name+session.Engine.QuoteIdentifier+"=?") - } - - return colNames, args -} - func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int64, error) { table := session.Engine.AutoMap(bean) - session.Statement.Table = table - colNames, args := session.BuildConditions(table, bean) + session.Statement.RefTable = table + colNames, args := BuildConditions(session.Engine, table, bean) var condiColNames []string var condiArgs []interface{} if len(condiBean) > 0 { - condiColNames, condiArgs = session.BuildConditions(table, condiBean[0]) + condiColNames, condiArgs = BuildConditions(session.Engine, table, condiBean[0]) } var condition = "" @@ -590,7 +559,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 statement := fmt.Sprintf("UPDATE %v%v%v SET %v %v", session.Engine.QuoteIdentifier, - table.Name, + session.Statement.TableName(), session.Engine.QuoteIdentifier, strings.Join(colNames, ", "), condition) @@ -611,8 +580,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.AutoMap(bean) - session.Statement.Table = table - colNames, args := session.BuildConditions(table, bean) + session.Statement.RefTable = table + colNames, args := BuildConditions(session.Engine, table, bean) var condition = "" st := session.Statement @@ -629,7 +598,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) { statement := fmt.Sprintf("DELETE FROM %v%v%v %v", session.Engine.QuoteIdentifier, - table.Name, + session.Statement.TableName(), session.Engine.QuoteIdentifier, condition) diff --git a/statement.go b/statement.go index f27f3fda..c80c0cf6 100644 --- a/statement.go +++ b/statement.go @@ -2,22 +2,26 @@ package xorm import ( "fmt" + "reflect" + "strconv" "strings" + "time" ) type Statement struct { - Table *Table - Engine *Engine - Start int - LimitN int - WhereStr string - Params []interface{} - OrderStr string - JoinStr string - GroupByStr string - HavingStr string - ColumnStr string - BeanArgs []interface{} + RefTable *Table + Engine *Engine + Start int + LimitN int + WhereStr string + Params []interface{} + OrderStr string + JoinStr string + GroupByStr string + HavingStr string + ColumnStr string + AltTableName string + BeanArgs []interface{} } func MakeArray(elem string, count int) []string { @@ -29,7 +33,7 @@ func MakeArray(elem string, count int) []string { } func (statement *Statement) Init() { - statement.Table = nil + statement.RefTable = nil statement.Start = 0 statement.LimitN = 0 statement.WhereStr = "" @@ -39,6 +43,7 @@ func (statement *Statement) Init() { statement.GroupByStr = "" statement.HavingStr = "" statement.ColumnStr = "" + statement.AltTableName = "" statement.BeanArgs = make([]interface{}, 0) } @@ -47,7 +52,54 @@ func (statement *Statement) Where(querystring string, args ...interface{}) { statement.Params = args } -func (statement *Statement) Id(id int) { +func (statement *Statement) Table(tableName string) { + statement.AltTableName = tableName +} + +func BuildConditions(engine *Engine, table *Table, bean interface{}) ([]string, []interface{}) { + colNames := make([]string, 0) + var args = make([]interface{}, 0) + for _, col := range table.Columns { + fieldValue := reflect.Indirect(reflect.ValueOf(bean)).FieldByName(col.FieldName) + fieldType := reflect.TypeOf(fieldValue.Interface()) + val := fieldValue.Interface() + switch fieldType.Kind() { + case reflect.String: + if fieldValue.String() == "" { + continue + } + case reflect.Int, reflect.Int32, reflect.Int64: + if fieldValue.Int() == 0 { + continue + } + case reflect.Struct: + if fieldType == reflect.TypeOf(time.Now()) { + t := fieldValue.Interface().(time.Time) + if t.IsZero() { + continue + } + } + default: + continue + } + args = append(args, val) + colNames = append(colNames, engine.QuoteIdentifier+col.Name+engine.QuoteIdentifier+"=?") + } + + return colNames, args +} + +func (statement *Statement) TableName() string { + if statement.AltTableName != "" { + return statement.AltTableName + } + if statement.RefTable != nil { + return statement.RefTable.Name + } + return "" +} + +func (statement *Statement) Id(id int64) { if statement.WhereStr == "" { statement.WhereStr = "(id)=?" statement.Params = []interface{}{id} @@ -96,22 +148,100 @@ func (statement *Statement) Having(conditions string) { statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) } +func (statement *Statement) genColumnStr(col *Column) string { + sql := "`" + col.Name + "` " + if col.SQLType == Date { + sql += " datetime " + } else { + if statement.Engine.DriverName == SQLITE && col.IsPrimaryKey { + sql += "integer" + } else { + sql += col.SQLType.Name + } + if statement.Engine.DriverName != SQLITE && col.SQLType != Text { + if col.SQLType != Decimal { + sql += "(" + strconv.Itoa(col.Length) + ")" + } else { + sql += "(" + strconv.Itoa(col.Length) + "," + strconv.Itoa(col.Length2) + ")" + } + } + } + + if col.Nullable { + sql += " NULL " + } else { + sql += " NOT NULL " + } + //fmt.Println(key) + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + } + if col.IsAutoIncrement { + sql += statement.Engine.AutoIncrement + " " + } + if col.IsUnique { + sql += "Unique " + } + return sql +} + +func (statement *Statement) selectColumnStr() string { + table := statement.RefTable + colNames := make([]string, 0) + for _, col := range table.Columns { + colNames = append(colNames, statement.TableName()+"."+col.Name) + } + return strings.Join(colNames, ", ") +} + +func (statement *Statement) genCreateSQL() string { + sql := "CREATE TABLE IF NOT EXISTS `" + statement.TableName() + "` (" + for _, col := range statement.RefTable.Columns { + sql += statement.genColumnStr(&col) + sql += "," + } + sql = sql[:len(sql)-2] + ");" + return sql +} + +func (statement *Statement) genDropSQL() string { + sql := "DROP TABLE IF EXISTS `" + statement.TableName() + "`;" + return sql +} + func (statement Statement) generateSql() string { - columnStr := statement.Table.ColumnStr() + columnStr := statement.selectColumnStr() return statement.genSelectSql(columnStr) } -func (statement Statement) genCountSql() string { - return statement.genSelectSql("count(*) as total") +func (statement Statement) genGetSql(bean interface{}) (string, []interface{}) { + table := statement.Engine.AutoMap(bean) + statement.RefTable = table + + colNames, args := BuildConditions(statement.Engine, table, bean) + statement.ColumnStr = strings.Join(colNames, " and ") + statement.BeanArgs = args + + return statement.generateSql(), append(statement.Params, statement.BeanArgs...) +} + +func (statement Statement) genCountSql(bean interface{}) (string, []interface{}) { + table := statement.Engine.AutoMap(bean) + statement.RefTable = table + + colNames, args := BuildConditions(statement.Engine, table, bean) + statement.ColumnStr = strings.Join(colNames, " and ") + statement.BeanArgs = args + return statement.genSelectSql("count(*) as total"), append(statement.Params, statement.BeanArgs...) } func (statement Statement) genSelectSql(columnStr string) (a string) { if statement.Engine.DriverName == MSSQL { if statement.Start > 0 { a = fmt.Sprintf("select ROW_NUMBER() OVER(order by %v )as rownum,%v from %v", - statement.Table.PKColumn().Name, + statement.RefTable.PKColumn().Name, columnStr, - statement.Table.Name) + statement.TableName()) if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) if statement.ColumnStr != "" { @@ -127,7 +257,7 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { statement.Start, statement.LimitN) } else if statement.LimitN > 0 { - a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.Table.Name) + a = fmt.Sprintf("SELECT top %v %v FROM %v", statement.LimitN, columnStr, statement.TableName()) if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) if statement.ColumnStr != "" { @@ -146,7 +276,7 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } } else { - a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) + a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.TableName()) if statement.WhereStr != "" { a = fmt.Sprintf("%v WHERE %v", a, statement.WhereStr) if statement.ColumnStr != "" { @@ -166,7 +296,7 @@ func (statement Statement) genSelectSql(columnStr string) (a string) { } } } else { - a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.Table.Name) + a = fmt.Sprintf("SELECT %v FROM %v", columnStr, statement.TableName()) if statement.JoinStr != "" { a = fmt.Sprintf("%v %v", a, statement.JoinStr) } diff --git a/table.go b/table.go index 1ef24f05..d79e6f86 100644 --- a/table.go +++ b/table.go @@ -3,7 +3,7 @@ package xorm import ( "reflect" "strconv" - "strings" + //"strings" "time" ) @@ -70,14 +70,6 @@ type Table struct { PrimaryKey string } -func (table *Table) ColumnStr() string { - colNames := make([]string, 0) - for _, col := range table.Columns { - colNames = append(colNames, table.Name+"."+col.Name) - } - return strings.Join(colNames, ", ") -} - func (table *Table) PKColumn() Column { return table.Columns[table.PrimaryKey] } diff --git a/xorm.go b/xorm.go index ae5932aa..6c2d7911 100644 --- a/xorm.go +++ b/xorm.go @@ -11,6 +11,7 @@ func Create(driverName string, dataSourceName string) Engine { engine.Tables = make(map[reflect.Type]Table) engine.Statement.Engine = &engine engine.InsertMany = true + engine.TagIdentifier = "xorm" if driverName == SQLITE { engine.AutoIncrement = "AUTOINCREMENT" } else { diff --git a/xorm_test.go b/xorm_test.go index c05e4d92..381d771c 100644 --- a/xorm_test.go +++ b/xorm_test.go @@ -356,6 +356,64 @@ func combineTransaction(t *testing.T) { } } +func table(t *testing.T) { + engine.Table("user_user").CreateTables(&Userinfo{}) +} + +func createMultiTables(t *testing.T) { + session, err := engine.MakeSession() + defer session.Close() + if err != nil { + t.Error(err) + return + } + + user := &Userinfo{} + session.Begin() + for i := 0; i < 100; i++ { + err = session.Table(fmt.Sprintf("user_%v", i)).CreateTable(user) + if err != nil { + session.Rollback() + t.Error(err) + return + } + } + err = session.Commit() + if err != nil { + t.Error(err) + } +} + +func tableOp(t *testing.T) { + user := Userinfo{Username: "tablexiao", Departname: "dev", Alias: "lunny", Created: time.Now()} + tableName := fmt.Sprintf("user_%v", len(user.Username)) + id, err := engine.Table(tableName).Insert(&user) + if err != nil { + t.Error(err) + } + + err = engine.Table(tableName).Get(&Userinfo{Username: "tablexiao"}) + if err != nil { + t.Error(err) + } + + users := make([]Userinfo, 0) + err = engine.Table(tableName).Find(&users) + if err != nil { + t.Error(err) + } + + _, err = engine.Table(tableName).Id(id).Update(&Userinfo{Username: "tableda"}) + if err != nil { + t.Error(err) + } + + _, err = engine.Table(tableName).Id(id).Delete(&Userinfo{}) + if err != nil { + t.Error(err) + } +} + func TestMysql(t *testing.T) { engine = xorm.Create("mysql", "root:123@/test?charset=utf8") engine.ShowSQL = true @@ -381,6 +439,9 @@ func TestMysql(t *testing.T) { having(t) transaction(t) combineTransaction(t) + table(t) + createMultiTables(t) + tableOp(t) } func TestSqlite(t *testing.T) { @@ -408,4 +469,7 @@ func TestSqlite(t *testing.T) { having(t) transaction(t) combineTransaction(t) + table(t) + createMultiTables(t) + tableOp(t) }