Merge branch 'master' into dump-infer-from-type

This commit is contained in:
Andrew Thornton 2021-04-20 17:49:37 +01:00
commit e9a66ac1ad
No known key found for this signature in database
GPG Key ID: 3CDE74631F13A748
27 changed files with 250 additions and 97 deletions

View File

@ -15,6 +15,7 @@ warningCode = 1
[rule.if-return] [rule.if-return]
[rule.increment-decrement] [rule.increment-decrement]
[rule.var-naming] [rule.var-naming]
arguments = [["ID", "UID", "UUID", "URL", "JSON"], []]
[rule.var-declaration] [rule.var-declaration]
[rule.package-comments] [rule.package-comments]
[rule.range] [rule.range]
@ -23,3 +24,4 @@ warningCode = 1
[rule.unexported-return] [rule.unexported-return]
[rule.indent-error-flow] [rule.indent-error-flow]
[rule.errorf] [rule.errorf]
[rule.struct-tag]

View File

@ -13,22 +13,26 @@ import (
"io" "io"
) )
// md5 hash string // Md5 return md5 hash string
func Md5(str string) string { func Md5(str string) string {
m := md5.New() m := md5.New()
io.WriteString(m, str) io.WriteString(m, str)
return fmt.Sprintf("%x", m.Sum(nil)) return fmt.Sprintf("%x", m.Sum(nil))
} }
// Encode Encode data
func Encode(data interface{}) ([]byte, error) { func Encode(data interface{}) ([]byte, error) {
//return JsonEncode(data) //return JsonEncode(data)
return GobEncode(data) return GobEncode(data)
} }
// Decode decode data
func Decode(data []byte, to interface{}) error { func Decode(data []byte, to interface{}) error {
//return JsonDecode(data, to) //return JsonDecode(data, to)
return GobDecode(data, to) return GobDecode(data, to)
} }
// GobEncode encode data with gob
func GobEncode(data interface{}) ([]byte, error) { func GobEncode(data interface{}) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
enc := gob.NewEncoder(&buf) enc := gob.NewEncoder(&buf)
@ -39,12 +43,14 @@ func GobEncode(data interface{}) ([]byte, error) {
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// GobDecode decode data with gob
func GobDecode(data []byte, to interface{}) error { func GobDecode(data []byte, to interface{}) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
dec := gob.NewDecoder(buf) dec := gob.NewDecoder(buf)
return dec.Decode(to) return dec.Decode(to)
} }
// JsonEncode encode data with json
func JsonEncode(data interface{}) ([]byte, error) { func JsonEncode(data interface{}) ([]byte, error) {
val, err := json.Marshal(data) val, err := json.Marshal(data)
if err != nil { if err != nil {
@ -53,6 +59,7 @@ func JsonEncode(data interface{}) ([]byte, error) {
return val, nil return val, nil
} }
// JsonDecode decode data with json
func JsonDecode(data []byte, to interface{}) error { func JsonDecode(data []byte, to interface{}) error {
return json.Unmarshal(data, to) return json.Unmarshal(data, to)
} }

View File

@ -23,6 +23,7 @@ var (
DefaultCacheSize = 200 DefaultCacheSize = 200
) )
// MapToSlice map query and struct as sql and args
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
vv := reflect.ValueOf(mp) vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
@ -44,6 +45,7 @@ func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
return query, args, err return query, args, err
} }
// StructToSlice converts a query and struct as sql and args
func StructToSlice(query string, st interface{}) (string, []interface{}, error) { func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
vv := reflect.ValueOf(st) vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
@ -176,6 +178,7 @@ func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
return db.QueryMapContext(context.Background(), query, mp) return db.QueryMapContext(context.Background(), query, mp)
} }
// QueryStructContext query rows with struct
func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
@ -184,10 +187,12 @@ func (db *DB) QueryStructContext(ctx context.Context, query string, st interface
return db.QueryContext(ctx, query, args...) return db.QueryContext(ctx, query, args...)
} }
// QueryStruct query rows with struct
func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
return db.QueryStructContext(context.Background(), query, st) return db.QueryStructContext(context.Background(), query, st)
} }
// QueryRowContext query row with args
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := db.QueryContext(ctx, query, args...) rows, err := db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
@ -196,10 +201,12 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interfa
return &Row{rows, nil} return &Row{rows, nil}
} }
// QueryRow query row with args
func (db *DB) QueryRow(query string, args ...interface{}) *Row { func (db *DB) QueryRow(query string, args ...interface{}) *Row {
return db.QueryRowContext(context.Background(), query, args...) return db.QueryRowContext(context.Background(), query, args...)
} }
// QueryRowMapContext query row with map
func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp) query, args, err := MapToSlice(query, mp)
if err != nil { if err != nil {
@ -208,10 +215,12 @@ func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface
return db.QueryRowContext(ctx, query, args...) return db.QueryRowContext(ctx, query, args...)
} }
// QueryRowMap query row with map
func (db *DB) QueryRowMap(query string, mp interface{}) *Row { func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
return db.QueryRowMapContext(context.Background(), query, mp) return db.QueryRowMapContext(context.Background(), query, mp)
} }
// QueryRowStructContext query row with struct
func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
@ -220,6 +229,7 @@ func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interf
return db.QueryRowContext(ctx, query, args...) return db.QueryRowContext(ctx, query, args...)
} }
// QueryRowStruct query row with struct
func (db *DB) QueryRowStruct(query string, st interface{}) *Row { func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
return db.QueryRowStructContext(context.Background(), query, st) return db.QueryRowStructContext(context.Background(), query, st)
} }
@ -239,10 +249,12 @@ func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{})
return db.ExecContext(ctx, query, args...) return db.ExecContext(ctx, query, args...)
} }
// ExecMap exec query with map
func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
return db.ExecMapContext(context.Background(), query, mp) return db.ExecMapContext(context.Background(), query, mp)
} }
// ExecStructContext exec query with map
func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
@ -251,6 +263,7 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{
return db.ExecContext(ctx, query, args...) return db.ExecContext(ctx, query, args...)
} }
// ExecContext exec query with args
func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
hookCtx := contexts.NewContextHook(ctx, query, args) hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := db.beforeProcess(hookCtx) ctx, err := db.beforeProcess(hookCtx)
@ -265,6 +278,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}
return res, nil return res, nil
} }
// ExecStruct exec query with struct
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
return db.ExecStructContext(context.Background(), query, st) return db.ExecStructContext(context.Background(), query, st)
} }
@ -288,6 +302,7 @@ func (db *DB) afterProcess(c *contexts.ContextHook) error {
return err return err
} }
// AddHook adds hook
func (db *DB) AddHook(h ...contexts.Hook) { func (db *DB) AddHook(h ...contexts.Hook) {
db.hooks.AddHook(h...) db.hooks.AddHook(h...)
} }

View File

@ -21,7 +21,7 @@ import (
var ( var (
dbtype = flag.String("dbtype", "sqlite3", "database type") dbtype = flag.String("dbtype", "sqlite3", "database type")
dbConn = flag.String("dbConn", "./db_test.db", "database connect string") dbConn = flag.String("dbConn", "./db_test.db", "database connect string")
createTableSql string createTableSQL string
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -29,12 +29,12 @@ func TestMain(m *testing.M) {
switch *dbtype { switch *dbtype {
case "sqlite3", "sqlite": case "sqlite3", "sqlite":
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " +
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
case "mysql": case "mysql":
fallthrough fallthrough
default: default:
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " +
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
} }
@ -66,7 +66,7 @@ func BenchmarkOriQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -121,7 +121,7 @@ func BenchmarkStructQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -166,7 +166,7 @@ func BenchmarkStruct2Query(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -212,7 +212,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -270,7 +270,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -321,7 +321,7 @@ func BenchmarkSliceStringQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -372,7 +372,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -426,7 +426,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -473,7 +473,7 @@ func BenchmarkMapStringQuery(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -519,7 +519,7 @@ func BenchmarkExec(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -544,7 +544,7 @@ func BenchmarkExecMap(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }
@ -577,7 +577,7 @@ func TestExecMap(t *testing.T) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -620,7 +620,7 @@ func TestExecStruct(t *testing.T) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -663,7 +663,7 @@ func BenchmarkExecStruct(b *testing.B) {
} }
defer db.Close() defer db.Close()
_, err = db.Exec(createTableSql) _, err = db.Exec(createTableSQL)
if err != nil { if err != nil {
b.Error(err) b.Error(err)
} }

View File

@ -11,11 +11,13 @@ import (
"sync" "sync"
) )
// Rows represents rows of table
type Rows struct { type Rows struct {
*sql.Rows *sql.Rows
db *DB db *DB
} }
// ToMapString returns all records
func (rs *Rows) ToMapString() ([]map[string]string, error) { func (rs *Rows) ToMapString() ([]map[string]string, error) {
cols, err := rs.Columns() cols, err := rs.Columns()
if err != nil { if err != nil {
@ -34,7 +36,7 @@ func (rs *Rows) ToMapString() ([]map[string]string, error) {
return results, nil return results, nil
} }
// scan data to a struct's pointer according field index // ScanStructByIndex scan data to a struct's pointer according field index
func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
if len(dest) == 0 { if len(dest) == 0 {
return errors.New("at least one struct") return errors.New("at least one struct")
@ -94,7 +96,7 @@ func fieldByName(v reflect.Value, name string) reflect.Value {
return reflect.Zero(t) return reflect.Zero(t)
} }
// scan data to a struct's pointer according field name // ScanStructByName scan data to a struct's pointer according field name
func (rs *Rows) ScanStructByName(dest interface{}) error { func (rs *Rows) ScanStructByName(dest interface{}) error {
vv := reflect.ValueOf(dest) vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
@ -120,7 +122,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error {
return rs.Rows.Scan(newDest...) return rs.Rows.Scan(newDest...)
} }
// scan data to a slice's pointer, slice's length should equal to columns' number // ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number
func (rs *Rows) ScanSlice(dest interface{}) error { func (rs *Rows) ScanSlice(dest interface{}) error {
vv := reflect.ValueOf(dest) vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
@ -155,7 +157,7 @@ func (rs *Rows) ScanSlice(dest interface{}) error {
return nil return nil
} }
// scan data to a map's pointer // ScanMap scan data to a map's pointer
func (rs *Rows) ScanMap(dest interface{}) error { func (rs *Rows) ScanMap(dest interface{}) error {
vv := reflect.ValueOf(dest) vv := reflect.ValueOf(dest)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
@ -187,6 +189,7 @@ func (rs *Rows) ScanMap(dest interface{}) error {
return nil return nil
} }
// Row reprents a row of a tab
type Row struct { type Row struct {
rows *Rows rows *Rows
// One of these two will be non-nil: // One of these two will be non-nil:
@ -205,6 +208,7 @@ func NewRow(rows *Rows, err error) *Row {
return &Row{rows, err} return &Row{rows, err}
} }
// Columns returns all columns of the row
func (row *Row) Columns() ([]string, error) { func (row *Row) Columns() ([]string, error) {
if row.err != nil { if row.err != nil {
return nil, row.err return nil, row.err
@ -212,6 +216,7 @@ func (row *Row) Columns() ([]string, error) {
return row.rows.Columns() return row.rows.Columns()
} }
// Scan retrieves all row column values
func (row *Row) Scan(dest ...interface{}) error { func (row *Row) Scan(dest ...interface{}) error {
if row.err != nil { if row.err != nil {
return row.err return row.err
@ -238,6 +243,7 @@ func (row *Row) Scan(dest ...interface{}) error {
return row.rows.Close() return row.rows.Close()
} }
// ScanStructByName retrieves all row column values into a struct
func (row *Row) ScanStructByName(dest interface{}) error { func (row *Row) ScanStructByName(dest interface{}) error {
if row.err != nil { if row.err != nil {
return row.err return row.err
@ -258,6 +264,7 @@ func (row *Row) ScanStructByName(dest interface{}) error {
return row.rows.Close() return row.rows.Close()
} }
// ScanStructByIndex retrieves all row column values into a struct
func (row *Row) ScanStructByIndex(dest interface{}) error { func (row *Row) ScanStructByIndex(dest interface{}) error {
if row.err != nil { if row.err != nil {
return row.err return row.err
@ -278,7 +285,7 @@ func (row *Row) ScanStructByIndex(dest interface{}) error {
return row.rows.Close() return row.rows.Close()
} }
// scan data to a slice's pointer, slice's length should equal to columns' number // ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number
func (row *Row) ScanSlice(dest interface{}) error { func (row *Row) ScanSlice(dest interface{}) error {
if row.err != nil { if row.err != nil {
return row.err return row.err
@ -300,7 +307,7 @@ func (row *Row) ScanSlice(dest interface{}) error {
return row.rows.Close() return row.rows.Close()
} }
// scan data to a map's pointer // ScanMap scan data to a map's pointer
func (row *Row) ScanMap(dest interface{}) error { func (row *Row) ScanMap(dest interface{}) error {
if row.err != nil { if row.err != nil {
return row.err return row.err
@ -322,6 +329,7 @@ func (row *Row) ScanMap(dest interface{}) error {
return row.rows.Close() return row.rows.Close()
} }
// ToMapString returns all clumns of this record
func (row *Row) ToMapString() (map[string]string, error) { func (row *Row) ToMapString() (map[string]string, error) {
cols, err := row.Columns() cols, err := row.Columns()
if err != nil { if err != nil {

View File

@ -10,12 +10,14 @@ import (
"time" "time"
) )
// NullTime defines a customize type NullTime
type NullTime time.Time type NullTime time.Time
var ( var (
_ driver.Valuer = NullTime{} _ driver.Valuer = NullTime{}
) )
// Scan implements driver.Valuer
func (ns *NullTime) Scan(value interface{}) error { func (ns *NullTime) Scan(value interface{}) error {
if value == nil { if value == nil {
return nil return nil
@ -58,9 +60,11 @@ func convertTime(dest *NullTime, src interface{}) error {
return nil return nil
} }
// EmptyScanner represents an empty scanner
type EmptyScanner struct { type EmptyScanner struct {
} }
// Scan implements
func (EmptyScanner) Scan(src interface{}) error { func (EmptyScanner) Scan(src interface{}) error {
return nil return nil
} }

View File

@ -21,6 +21,7 @@ type Stmt struct {
query string query string
} }
// PrepareContext creates a prepare statement
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
names := make(map[string]int) names := make(map[string]int)
var i int var i int
@ -42,10 +43,12 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
return &Stmt{stmt, db, names, query}, nil return &Stmt{stmt, db, names, query}, nil
} }
// Prepare creates a prepare statement
func (db *DB) Prepare(query string) (*Stmt, error) { func (db *DB) Prepare(query string) (*Stmt, error) {
return db.PrepareContext(context.Background(), query) return db.PrepareContext(context.Background(), query)
} }
// ExecMapContext execute with map
func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) {
vv := reflect.ValueOf(mp) vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
@ -59,10 +62,12 @@ func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result,
return s.ExecContext(ctx, args...) return s.ExecContext(ctx, args...)
} }
// ExecMap executes with map
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
return s.ExecMapContext(context.Background(), mp) return s.ExecMapContext(context.Background(), mp)
} }
// ExecStructContext executes with struct
func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) {
vv := reflect.ValueOf(st) vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
@ -76,10 +81,12 @@ func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Resul
return s.ExecContext(ctx, args...) return s.ExecContext(ctx, args...)
} }
// ExecStruct executes with struct
func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
return s.ExecStructContext(context.Background(), st) return s.ExecStructContext(context.Background(), st)
} }
// ExecContext with args
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
hookCtx := contexts.NewContextHook(ctx, s.query, args) hookCtx := contexts.NewContextHook(ctx, s.query, args)
ctx, err := s.db.beforeProcess(hookCtx) ctx, err := s.db.beforeProcess(hookCtx)
@ -94,6 +101,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
return res, nil return res, nil
} }
// QueryContext query with args
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
hookCtx := contexts.NewContextHook(ctx, s.query, args) hookCtx := contexts.NewContextHook(ctx, s.query, args)
ctx, err := s.db.beforeProcess(hookCtx) ctx, err := s.db.beforeProcess(hookCtx)
@ -108,10 +116,12 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
return &Rows{rows, s.db}, nil return &Rows{rows, s.db}, nil
} }
// Query query with args
func (s *Stmt) Query(args ...interface{}) (*Rows, error) { func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...) return s.QueryContext(context.Background(), args...)
} }
// QueryMapContext query with map
func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) {
vv := reflect.ValueOf(mp) vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
@ -126,10 +136,12 @@ func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, erro
return s.QueryContext(ctx, args...) return s.QueryContext(ctx, args...)
} }
// QueryMap query with map
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
return s.QueryMapContext(context.Background(), mp) return s.QueryMapContext(context.Background(), mp)
} }
// QueryStructContext query with struct
func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) {
vv := reflect.ValueOf(st) vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
@ -144,19 +156,23 @@ func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, e
return s.QueryContext(ctx, args...) return s.QueryContext(ctx, args...)
} }
// QueryStruct query with struct
func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
return s.QueryStructContext(context.Background(), st) return s.QueryStructContext(context.Background(), st)
} }
// QueryRowContext query row with args
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
rows, err := s.QueryContext(ctx, args...) rows, err := s.QueryContext(ctx, args...)
return &Row{rows, err} return &Row{rows, err}
} }
// QueryRow query row with args
func (s *Stmt) QueryRow(args ...interface{}) *Row { func (s *Stmt) QueryRow(args ...interface{}) *Row {
return s.QueryRowContext(context.Background(), args...) return s.QueryRowContext(context.Background(), args...)
} }
// QueryRowMapContext query row with map
func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
vv := reflect.ValueOf(mp) vv := reflect.ValueOf(mp)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
@ -171,10 +187,12 @@ func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
return s.QueryRowContext(ctx, args...) return s.QueryRowContext(ctx, args...)
} }
// QueryRowMap query row with map
func (s *Stmt) QueryRowMap(mp interface{}) *Row { func (s *Stmt) QueryRowMap(mp interface{}) *Row {
return s.QueryRowMapContext(context.Background(), mp) return s.QueryRowMapContext(context.Background(), mp)
} }
// QueryRowStructContext query row with struct
func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
vv := reflect.ValueOf(st) vv := reflect.ValueOf(st)
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
@ -189,6 +207,7 @@ func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
return s.QueryRowContext(ctx, args...) return s.QueryRowContext(ctx, args...)
} }
// QueryRowStruct query row with struct
func (s *Stmt) QueryRowStruct(st interface{}) *Row { func (s *Stmt) QueryRowStruct(st interface{}) *Row {
return s.QueryRowStructContext(context.Background(), st) return s.QueryRowStructContext(context.Background(), st)
} }

View File

@ -22,6 +22,7 @@ type Tx struct {
ctx context.Context ctx context.Context
} }
// BeginTx begin a transaction with option
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil)
ctx, err := db.beforeProcess(hookCtx) ctx, err := db.beforeProcess(hookCtx)
@ -36,10 +37,12 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
return &Tx{tx, db, ctx}, nil return &Tx{tx, db, ctx}, nil
} }
// Begin begins a transaction
func (db *DB) Begin() (*Tx, error) { func (db *DB) Begin() (*Tx, error) {
return db.BeginTx(context.Background(), nil) return db.BeginTx(context.Background(), nil)
} }
// Commit submit the transaction
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
hookCtx := contexts.NewContextHook(tx.ctx, "COMMIT", nil) hookCtx := contexts.NewContextHook(tx.ctx, "COMMIT", nil)
ctx, err := tx.db.beforeProcess(hookCtx) ctx, err := tx.db.beforeProcess(hookCtx)
@ -48,12 +51,10 @@ func (tx *Tx) Commit() error {
} }
err = tx.Tx.Commit() err = tx.Tx.Commit()
hookCtx.End(ctx, nil, err) hookCtx.End(ctx, nil, err)
if err := tx.db.afterProcess(hookCtx); err != nil { return tx.db.afterProcess(hookCtx)
return err
}
return nil
} }
// Rollback rollback the transaction
func (tx *Tx) Rollback() error { func (tx *Tx) Rollback() error {
hookCtx := contexts.NewContextHook(tx.ctx, "ROLLBACK", nil) hookCtx := contexts.NewContextHook(tx.ctx, "ROLLBACK", nil)
ctx, err := tx.db.beforeProcess(hookCtx) ctx, err := tx.db.beforeProcess(hookCtx)
@ -62,12 +63,10 @@ func (tx *Tx) Rollback() error {
} }
err = tx.Tx.Rollback() err = tx.Tx.Rollback()
hookCtx.End(ctx, nil, err) hookCtx.End(ctx, nil, err)
if err := tx.db.afterProcess(hookCtx); err != nil { return tx.db.afterProcess(hookCtx)
return err
}
return nil
} }
// PrepareContext prepare the query
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
names := make(map[string]int) names := make(map[string]int)
var i int var i int
@ -89,19 +88,23 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
return &Stmt{stmt, tx.db, names, query}, nil return &Stmt{stmt, tx.db, names, query}, nil
} }
// Prepare prepare the query
func (tx *Tx) Prepare(query string) (*Stmt, error) { func (tx *Tx) Prepare(query string) (*Stmt, error) {
return tx.PrepareContext(context.Background(), query) return tx.PrepareContext(context.Background(), query)
} }
// StmtContext creates Stmt with context
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt) stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt)
return stmt return stmt
} }
// Stmt creates Stmt
func (tx *Tx) Stmt(stmt *Stmt) *Stmt { func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
return tx.StmtContext(context.Background(), stmt) return tx.StmtContext(context.Background(), stmt)
} }
// ExecMapContext executes query with args in a map
func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) {
query, args, err := MapToSlice(query, mp) query, args, err := MapToSlice(query, mp)
if err != nil { if err != nil {
@ -110,10 +113,12 @@ func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{})
return tx.ExecContext(ctx, query, args...) return tx.ExecContext(ctx, query, args...)
} }
// ExecMap executes query with args in a map
func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
return tx.ExecMapContext(context.Background(), query, mp) return tx.ExecMapContext(context.Background(), query, mp)
} }
// ExecStructContext executes query with args in a struct
func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
@ -122,6 +127,7 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{
return tx.ExecContext(ctx, query, args...) return tx.ExecContext(ctx, query, args...)
} }
// ExecContext executes a query with args
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
hookCtx := contexts.NewContextHook(ctx, query, args) hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := tx.db.beforeProcess(hookCtx) ctx, err := tx.db.beforeProcess(hookCtx)
@ -136,10 +142,12 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}
return res, err return res, err
} }
// ExecStruct executes query with args in a struct
func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
return tx.ExecStructContext(context.Background(), query, st) return tx.ExecStructContext(context.Background(), query, st)
} }
// QueryContext query with args
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
hookCtx := contexts.NewContextHook(ctx, query, args) hookCtx := contexts.NewContextHook(ctx, query, args)
ctx, err := tx.db.beforeProcess(hookCtx) ctx, err := tx.db.beforeProcess(hookCtx)
@ -157,10 +165,12 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{
return &Rows{rows, tx.db}, nil return &Rows{rows, tx.db}, nil
} }
// Query query with args
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
return tx.QueryContext(context.Background(), query, args...) return tx.QueryContext(context.Background(), query, args...)
} }
// QueryMapContext query with args in a map
func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) {
query, args, err := MapToSlice(query, mp) query, args, err := MapToSlice(query, mp)
if err != nil { if err != nil {
@ -169,10 +179,12 @@ func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{})
return tx.QueryContext(ctx, query, args...) return tx.QueryContext(ctx, query, args...)
} }
// QueryMap query with args in a map
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
return tx.QueryMapContext(context.Background(), query, mp) return tx.QueryMapContext(context.Background(), query, mp)
} }
// QueryStructContext query with args in struct
func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
@ -181,19 +193,23 @@ func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface
return tx.QueryContext(ctx, query, args...) return tx.QueryContext(ctx, query, args...)
} }
// QueryStruct query with args in struct
func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
return tx.QueryStructContext(context.Background(), query, st) return tx.QueryStructContext(context.Background(), query, st)
} }
// QueryRowContext query one row with args
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
return &Row{rows, err} return &Row{rows, err}
} }
// QueryRow query one row with args
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
return tx.QueryRowContext(context.Background(), query, args...) return tx.QueryRowContext(context.Background(), query, args...)
} }
// QueryRowMapContext query one row with args in a map
func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row {
query, args, err := MapToSlice(query, mp) query, args, err := MapToSlice(query, mp)
if err != nil { if err != nil {
@ -202,10 +218,12 @@ func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface
return tx.QueryRowContext(ctx, query, args...) return tx.QueryRowContext(ctx, query, args...)
} }
// QueryRowMap query one row with args in a map
func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
return tx.QueryRowMapContext(context.Background(), query, mp) return tx.QueryRowMapContext(context.Background(), query, mp)
} }
// QueryRowStructContext query one row with args in struct
func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row {
query, args, err := StructToSlice(query, st) query, args, err := StructToSlice(query, st)
if err != nil { if err != nil {
@ -214,6 +232,7 @@ func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interf
return tx.QueryRowContext(ctx, query, args...) return tx.QueryRowContext(ctx, query, args...)
} }
// QueryRowStruct query one row with args in struct
func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
return tx.QueryRowStructContext(context.Background(), query, st) return tx.QueryRowStructContext(context.Background(), query, st)
} }

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
) )
// Driver represents a database driver
type Driver interface { type Driver interface {
Parse(string, string) (*URI, error) Parse(string, string) (*URI, error)
} }
@ -16,6 +17,7 @@ var (
drivers = map[string]Driver{} drivers = map[string]Driver{}
) )
// RegisterDriver register a driver
func RegisterDriver(driverName string, driver Driver) { func RegisterDriver(driverName string, driver Driver) {
if driver == nil { if driver == nil {
panic("core: Register driver is nil") panic("core: Register driver is nil")
@ -26,10 +28,12 @@ func RegisterDriver(driverName string, driver Driver) {
drivers[driverName] = driver drivers[driverName] = driver
} }
// QueryDriver query a driver with name
func QueryDriver(driverName string) Driver { func QueryDriver(driverName string) Driver {
return drivers[driverName] return drivers[driverName]
} }
// RegisteredDriverSize returned all drivers's length
func RegisteredDriverSize() int { func RegisteredDriverSize() int {
return len(drivers) return len(drivers)
} }
@ -38,7 +42,7 @@ func RegisteredDriverSize() int {
func OpenDialect(driverName, connstr string) (Dialect, error) { func OpenDialect(driverName, connstr string) (Dialect, error) {
driver := QueryDriver(driverName) driver := QueryDriver(driverName)
if driver == nil { if driver == nil {
return nil, fmt.Errorf("Unsupported driver name: %v", driverName) return nil, fmt.Errorf("unsupported driver name: %v", driverName)
} }
uri, err := driver.Parse(driverName, connstr) uri, err := driver.Parse(driverName, connstr)
@ -48,7 +52,7 @@ func OpenDialect(driverName, connstr string) (Dialect, error) {
dialect := QueryDialect(uri.DBType) dialect := QueryDialect(uri.DBType)
if dialect == nil { if dialect == nil {
return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) return nil, fmt.Errorf("unsupported dialect type: %v", uri.DBType)
} }
dialect.Init(uri) dialect.Init(uri)

View File

@ -38,6 +38,7 @@ func convertQuestionMark(sql, prefix string, start int) string {
return buf.String() return buf.String()
} }
// Do implements Filter
func (s *SeqFilter) Do(sql string) string { func (s *SeqFilter) Do(sql string) string {
return convertQuestionMark(sql, s.Prefix, s.Start) return convertQuestionMark(sql, s.Prefix, s.Start)
} }

View File

@ -38,6 +38,7 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{}
return return
} }
// FormatColumnTime format column time
func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) { func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) {
if t.IsZero() { if t.IsZero() {
if col.Nullable { if col.Nullable {

View File

@ -97,6 +97,7 @@ func TestDeleted(t *testing.T) {
// Test normal Find() // Test normal Find()
var records1 []Deleted var records1 []Deleted
err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &Deleted{}) err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &Deleted{})
assert.NoError(t, err)
assert.EqualValues(t, 3, len(records1)) assert.EqualValues(t, 3, len(records1))
// Test normal Get() // Test normal Get()
@ -132,6 +133,7 @@ func TestDeleted(t *testing.T) {
record2 := &Deleted{} record2 := &Deleted{}
has, err = testEngine.ID(2).Get(record2) has, err = testEngine.ID(2).Get(record2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, has)
assert.True(t, record2.DeletedAt.IsZero()) assert.True(t, record2.DeletedAt.IsZero())
// Test find all records whatever `deleted`. // Test find all records whatever `deleted`.

View File

@ -166,10 +166,7 @@ func createEngine(dbType, connStr string) error {
for _, table := range tables { for _, table := range tables {
tableNames = append(tableNames, table.Name) tableNames = append(tableNames, table.Name)
} }
if err = testEngine.DropTables(tableNames...); err != nil { return testEngine.DropTables(tableNames...)
return err
}
return nil
} }
// PrepareEngine prepare tests ORM engine // PrepareEngine prepare tests ORM engine

View File

@ -12,6 +12,7 @@ import (
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// ConvertIDSQL converts SQL with id
func (statement *Statement) ConvertIDSQL(sqlStr string) string { func (statement *Statement) ConvertIDSQL(sqlStr string) string {
if statement.RefTable != nil { if statement.RefTable != nil {
cols := statement.RefTable.PKColumns() cols := statement.RefTable.PKColumns()
@ -37,6 +38,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string {
return "" return ""
} }
// ConvertUpdateSQL converts update SQL
func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) {
if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
return "", "" return "", ""

View File

@ -12,6 +12,7 @@ import (
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// ErrUnsupportedExprType represents an error with unsupported express type
type ErrUnsupportedExprType struct { type ErrUnsupportedExprType struct {
tp string tp string
} }

View File

@ -14,6 +14,7 @@ import (
"xorm.io/xorm/schemas" "xorm.io/xorm/schemas"
) )
// GenQuerySQL generate query SQL
func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) {
if len(sqlOrArgs) > 0 { if len(sqlOrArgs) > 0 {
return statement.ConvertSQLOrArgs(sqlOrArgs...) return statement.ConvertSQLOrArgs(sqlOrArgs...)
@ -72,6 +73,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int
return sqlStr, args, nil return sqlStr, args, nil
} }
// GenSumSQL generates sum SQL
func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
if statement.RawSQL != "" { if statement.RawSQL != "" {
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
@ -102,6 +104,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri
return sqlStr, append(statement.joinArgs, condArgs...), nil return sqlStr, append(statement.joinArgs, condArgs...), nil
} }
// GenGetSQL generates Get SQL
func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) {
v := rValue(bean) v := rValue(bean)
isStruct := v.Kind() == reflect.Struct isStruct := v.Kind() == reflect.Struct
@ -316,6 +319,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB
return buf.String(), condArgs, nil return buf.String(), condArgs, nil
} }
// GenExistSQL generates Exist SQL
func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) {
if statement.RawSQL != "" { if statement.RawSQL != "" {
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil
@ -385,6 +389,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac
return sqlStr, args, nil return sqlStr, args, nil
} }
// GenFindSQL generates Find SQL
func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) {
if statement.RawSQL != "" { if statement.RawSQL != "" {
return statement.GenRawSQL(), statement.RawParams, nil return statement.GenRawSQL(), statement.RawParams, nil

View File

@ -90,19 +90,17 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ
return statement return statement
} }
// SetTableName set table name
func (statement *Statement) SetTableName(tableName string) { func (statement *Statement) SetTableName(tableName string) {
statement.tableName = tableName statement.tableName = tableName
} }
func (statement *Statement) omitStr() string {
return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,")
}
// GenRawSQL generates correct raw sql // GenRawSQL generates correct raw sql
func (statement *Statement) GenRawSQL() string { func (statement *Statement) GenRawSQL() string {
return statement.ReplaceQuote(statement.RawSQL) return statement.ReplaceQuote(statement.RawSQL)
} }
// GenCondSQL generates condition SQL
func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) {
condSQL, condArgs, err := builder.ToSQL(condOrBuilder) condSQL, condArgs, err := builder.ToSQL(condOrBuilder)
if err != nil { if err != nil {
@ -111,6 +109,7 @@ func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []int
return statement.ReplaceQuote(condSQL), condArgs, nil return statement.ReplaceQuote(condSQL), condArgs, nil
} }
// ReplaceQuote replace sql key words with quote
func (statement *Statement) ReplaceQuote(sql string) string { func (statement *Statement) ReplaceQuote(sql string) string {
if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL ||
statement.dialect.URI().DBType == schemas.SQLITE { statement.dialect.URI().DBType == schemas.SQLITE {
@ -119,11 +118,12 @@ func (statement *Statement) ReplaceQuote(sql string) string {
return statement.dialect.Quoter().Replace(sql) return statement.dialect.Quoter().Replace(sql)
} }
// SetContextCache sets context cache
func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) {
statement.Context = ctxCache statement.Context = ctxCache
} }
// Init reset all the statement's fields // Reset reset all the statement's fields
func (statement *Statement) Reset() { func (statement *Statement) Reset() {
statement.RefTable = nil statement.RefTable = nil
statement.Start = 0 statement.Start = 0
@ -163,7 +163,7 @@ func (statement *Statement) Reset() {
statement.LastError = nil statement.LastError = nil
} }
// NoAutoCondition if you do not want convert bean's field as query condition, then use this function // SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function
func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement {
statement.NoAutoCondition = true statement.NoAutoCondition = true
if len(no) > 0 { if len(no) > 0 {
@ -271,6 +271,7 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement
return statement return statement
} }
// SetRefValue set ref value
func (statement *Statement) SetRefValue(v reflect.Value) error { func (statement *Statement) SetRefValue(v reflect.Value) error {
var err error var err error
statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v))
@ -285,6 +286,7 @@ func rValue(bean interface{}) reflect.Value {
return reflect.Indirect(reflect.ValueOf(bean)) return reflect.Indirect(reflect.ValueOf(bean))
} }
// SetRefBean set ref bean
func (statement *Statement) SetRefBean(bean interface{}) error { func (statement *Statement) SetRefBean(bean interface{}) error {
var err error var err error
statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean))
@ -390,6 +392,7 @@ func (statement *Statement) Cols(columns ...string) *Statement {
return statement return statement
} }
// ColumnStr returns column string
func (statement *Statement) ColumnStr() string { func (statement *Statement) ColumnStr() string {
return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") return statement.dialect.Quoter().Join(statement.ColumnMap, ", ")
} }
@ -493,11 +496,12 @@ func (statement *Statement) Asc(colNames ...string) *Statement {
return statement return statement
} }
// Conds returns condtions
func (statement *Statement) Conds() builder.Cond { func (statement *Statement) Conds() builder.Cond {
return statement.cond return statement.cond
} }
// Table tempororily set table name, the parameter could be a string or a pointer of struct // SetTable tempororily set table name, the parameter could be a string or a pointer of struct
func (statement *Statement) SetTable(tableNameOrBean interface{}) error { func (statement *Statement) SetTable(tableNameOrBean interface{}) error {
v := rValue(tableNameOrBean) v := rValue(tableNameOrBean)
t := v.Type() t := v.Type()
@ -564,7 +568,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement return statement
} }
// tbName get some table's table name // tbNameNoSchema get some table's table name
func (statement *Statement) tbNameNoSchema(table *schemas.Table) string { func (statement *Statement) tbNameNoSchema(table *schemas.Table) string {
if len(statement.AltTableName) > 0 { if len(statement.AltTableName) > 0 {
return statement.AltTableName return statement.AltTableName
@ -585,12 +589,13 @@ func (statement *Statement) Having(conditions string) *Statement {
return statement return statement
} }
// Unscoped always disable struct tag "deleted" // SetUnscoped always disable struct tag "deleted"
func (statement *Statement) SetUnscoped() *Statement { func (statement *Statement) SetUnscoped() *Statement {
statement.unscoped = true statement.unscoped = true
return statement return statement
} }
// GetUnscoped return true if it's unscoped
func (statement *Statement) GetUnscoped() bool { func (statement *Statement) GetUnscoped() bool {
return statement.unscoped return statement.unscoped
} }
@ -636,6 +641,7 @@ func (statement *Statement) genColumnStr() string {
return buf.String() return buf.String()
} }
// GenCreateTableSQL generated create table SQL
func (statement *Statement) GenCreateTableSQL() []string { func (statement *Statement) GenCreateTableSQL() []string {
statement.RefTable.StoreEngine = statement.StoreEngine statement.RefTable.StoreEngine = statement.StoreEngine
statement.RefTable.Charset = statement.Charset statement.RefTable.Charset = statement.Charset
@ -643,6 +649,7 @@ func (statement *Statement) GenCreateTableSQL() []string {
return s return s
} }
// GenIndexSQL generated create index SQL
func (statement *Statement) GenIndexSQL() []string { func (statement *Statement) GenIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
@ -659,6 +666,7 @@ func uniqueName(tableName, uqeName string) string {
return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
} }
// GenUniqueSQL generates unique SQL
func (statement *Statement) GenUniqueSQL() []string { func (statement *Statement) GenUniqueSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
@ -671,6 +679,7 @@ func (statement *Statement) GenUniqueSQL() []string {
return sqls return sqls
} }
// GenDelIndexSQL generate delete index SQL
func (statement *Statement) GenDelIndexSQL() []string { func (statement *Statement) GenDelIndexSQL() []string {
var sqls []string var sqls []string
tbName := statement.TableName() tbName := statement.TableName()
@ -896,6 +905,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{},
return builder.And(conds...), nil return builder.And(conds...), nil
} }
// BuildConds builds condition
func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
@ -911,12 +921,10 @@ func (statement *Statement) mergeConds(bean interface{}) error {
statement.cond = statement.cond.And(autoCond) statement.cond = statement.cond.And(autoCond)
} }
if err := statement.ProcessIDParam(); err != nil { return statement.ProcessIDParam()
return err
}
return nil
} }
// GenConds generates conditions
func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) { func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) {
if err := statement.mergeConds(bean); err != nil { if err := statement.mergeConds(bean); err != nil {
return "", nil, err return "", nil, err
@ -930,6 +938,7 @@ func (statement *Statement) quoteColumnStr(columnStr string) string {
return statement.dialect.Quoter().Join(columns, ",") return statement.dialect.Quoter().Join(columns, ",")
} }
// ConvertSQLOrArgs converts sql or args
func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) {
sql, args, err := convertSQLOrArgs(sqlOrArgs...) sql, args, err := convertSQLOrArgs(sqlOrArgs...)
if err != nil { if err != nil {

View File

@ -77,6 +77,7 @@ func convertArg(arg interface{}, convertFunc func(string) string) string {
const insertSelectPlaceHolder = true const insertSelectPlaceHolder = true
// WriteArg writes an arg
func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error {
switch argv := arg.(type) { switch argv := arg.(type) {
case *builder.Builder: case *builder.Builder:
@ -116,6 +117,7 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er
return nil return nil
} }
// WriteArgs writes args
func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error {
for i, arg := range args { for i, arg := range args {
if err := statement.WriteArg(w, arg); err != nil { if err := statement.WriteArg(w, arg); err != nil {

View File

@ -16,6 +16,7 @@ type Mapper interface {
Table2Obj(string) string Table2Obj(string) string
} }
// CacheMapper represents a cache mapper
type CacheMapper struct { type CacheMapper struct {
oriMapper Mapper oriMapper Mapper
obj2tableCache map[string]string obj2tableCache map[string]string
@ -24,12 +25,14 @@ type CacheMapper struct {
table2objMutex sync.RWMutex table2objMutex sync.RWMutex
} }
// NewCacheMapper creates a cache mapper
func NewCacheMapper(mapper Mapper) *CacheMapper { func NewCacheMapper(mapper Mapper) *CacheMapper {
return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string), return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string),
table2objCache: make(map[string]string), table2objCache: make(map[string]string),
} }
} }
// Obj2Table implements Mapper
func (m *CacheMapper) Obj2Table(o string) string { func (m *CacheMapper) Obj2Table(o string) string {
m.obj2tableMutex.RLock() m.obj2tableMutex.RLock()
t, ok := m.obj2tableCache[o] t, ok := m.obj2tableCache[o]
@ -45,6 +48,7 @@ func (m *CacheMapper) Obj2Table(o string) string {
return t return t
} }
// Table2Obj implements Mapper
func (m *CacheMapper) Table2Obj(t string) string { func (m *CacheMapper) Table2Obj(t string) string {
m.table2objMutex.RLock() m.table2objMutex.RLock()
o, ok := m.table2objCache[t] o, ok := m.table2objCache[t]
@ -60,15 +64,17 @@ func (m *CacheMapper) Table2Obj(t string) string {
return o return o
} }
// SameMapper implements IMapper and provides same name between struct and // SameMapper implements Mapper and provides same name between struct and
// database table // database table
type SameMapper struct { type SameMapper struct {
} }
// Obj2Table implements Mapper
func (m SameMapper) Obj2Table(o string) string { func (m SameMapper) Obj2Table(o string) string {
return o return o
} }
// Table2Obj implements Mapper
func (m SameMapper) Table2Obj(t string) string { func (m SameMapper) Table2Obj(t string) string {
return t return t
} }
@ -98,6 +104,7 @@ func snakeCasedName(name string) string {
return b2s(newstr) return b2s(newstr)
} }
// Obj2Table implements Mapper
func (mapper SnakeMapper) Obj2Table(name string) string { func (mapper SnakeMapper) Obj2Table(name string) string {
return snakeCasedName(name) return snakeCasedName(name)
} }
@ -127,6 +134,7 @@ func titleCasedName(name string) string {
return b2s(newstr) return b2s(newstr)
} }
// Table2Obj implements Mapper
func (mapper SnakeMapper) Table2Obj(name string) string { func (mapper SnakeMapper) Table2Obj(name string) string {
return titleCasedName(name) return titleCasedName(name)
} }
@ -168,10 +176,12 @@ func gonicCasedName(name string) string {
return strings.ToLower(string(newstr)) return strings.ToLower(string(newstr))
} }
// Obj2Table implements Mapper
func (mapper GonicMapper) Obj2Table(name string) string { func (mapper GonicMapper) Obj2Table(name string) string {
return gonicCasedName(name) return gonicCasedName(name)
} }
// Table2Obj implements Mapper
func (mapper GonicMapper) Table2Obj(name string) string { func (mapper GonicMapper) Table2Obj(name string) string {
newstr := make([]rune, 0) newstr := make([]rune, 0)
@ -234,14 +244,17 @@ type PrefixMapper struct {
Prefix string Prefix string
} }
// Obj2Table implements Mapper
func (mapper PrefixMapper) Obj2Table(name string) string { func (mapper PrefixMapper) Obj2Table(name string) string {
return mapper.Prefix + mapper.Mapper.Obj2Table(name) return mapper.Prefix + mapper.Mapper.Obj2Table(name)
} }
// Table2Obj implements Mapper
func (mapper PrefixMapper) Table2Obj(name string) string { func (mapper PrefixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
} }
// NewPrefixMapper creates a prefix mapper
func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper { func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper {
return PrefixMapper{mapper, prefix} return PrefixMapper{mapper, prefix}
} }
@ -252,14 +265,17 @@ type SuffixMapper struct {
Suffix string Suffix string
} }
// Obj2Table implements Mapper
func (mapper SuffixMapper) Obj2Table(name string) string { func (mapper SuffixMapper) Obj2Table(name string) string {
return mapper.Mapper.Obj2Table(name) + mapper.Suffix return mapper.Mapper.Obj2Table(name) + mapper.Suffix
} }
// Table2Obj implements Mapper
func (mapper SuffixMapper) Table2Obj(name string) string { func (mapper SuffixMapper) Table2Obj(name string) string {
return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
} }
// NewSuffixMapper creates a suffix mapper
func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper { func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper {
return SuffixMapper{mapper, suffix} return SuffixMapper{mapper, suffix}
} }

View File

@ -19,6 +19,7 @@ var (
tvCache sync.Map tvCache sync.Map
) )
// GetTableName returns table name
func GetTableName(mapper Mapper, v reflect.Value) string { func GetTableName(mapper Mapper, v reflect.Value) string {
if v.Type().Implements(tpTableName) { if v.Type().Implements(tpTableName) {
return v.Interface().(TableName).TableName() return v.Interface().(TableName).TableName()

View File

@ -13,6 +13,7 @@ import (
"time" "time"
) )
// enumerates all database mapping way
const ( const (
TWOSIDES = iota + 1 TWOSIDES = iota + 1
ONLYTODB ONLYTODB

View File

@ -28,6 +28,7 @@ func NewIndex(name string, indexType int) *Index {
return &Index{true, name, indexType, make([]string, 0)} return &Index{true, name, indexType, make([]string, 0)}
} }
// XName returns the special index name for the table
func (index *Index) XName(tableName string) string { func (index *Index) XName(tableName string) string {
if !strings.HasPrefix(index.Name, "UQE_") && if !strings.HasPrefix(index.Name, "UQE_") &&
!strings.HasPrefix(index.Name, "IDX_") { !strings.HasPrefix(index.Name, "IDX_") {
@ -43,11 +44,10 @@ func (index *Index) XName(tableName string) string {
// AddColumn add columns which will be composite index // AddColumn add columns which will be composite index
func (index *Index) AddColumn(cols ...string) { func (index *Index) AddColumn(cols ...string) {
for _, col := range cols { index.Cols = append(index.Cols, cols...)
index.Cols = append(index.Cols, col)
}
} }
// Equal return true if the two Index is equal
func (index *Index) Equal(dst *Index) bool { func (index *Index) Equal(dst *Index) bool {
if index.Type != dst.Type { if index.Type != dst.Type {
return false return false

View File

@ -11,13 +11,16 @@ import (
"xorm.io/xorm/internal/utils" "xorm.io/xorm/internal/utils"
) )
// PK represents primary key values
type PK []interface{} type PK []interface{}
// NewPK creates primay keys
func NewPK(pks ...interface{}) *PK { func NewPK(pks ...interface{}) *PK {
p := PK(pks) p := PK(pks)
return &p return &p
} }
// IsZero return true if primay keys are zero
func (p *PK) IsZero() bool { func (p *PK) IsZero() bool {
for _, k := range *p { for _, k := range *p {
if utils.IsZero(k) { if utils.IsZero(k) {
@ -27,6 +30,7 @@ func (p *PK) IsZero() bool {
return false return false
} }
// ToString convert to SQL string
func (p *PK) ToString() (string, error) { func (p *PK) ToString() (string, error) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf) enc := gob.NewEncoder(buf)
@ -34,6 +38,7 @@ func (p *PK) ToString() (string, error) {
return buf.String(), err return buf.String(), err
} }
// FromString reads content to load primary keys
func (p *PK) FromString(content string) error { func (p *PK) FromString(content string) error {
dec := gob.NewDecoder(bytes.NewBufferString(content)) dec := gob.NewDecoder(bytes.NewBufferString(content))
err := dec.Decode(p) err := dec.Decode(p)

View File

@ -16,10 +16,10 @@ type Quoter struct {
} }
var ( var (
// AlwaysFalseReverse always think it's not a reverse word // AlwaysNoReserve always think it's not a reverse word
AlwaysNoReserve = func(string) bool { return false } AlwaysNoReserve = func(string) bool { return false }
// AlwaysReverse always reverse the word // AlwaysReserve always reverse the word
AlwaysReserve = func(string) bool { return true } AlwaysReserve = func(string) bool { return true }
// CommanQuoteMark represnets the common quote mark // CommanQuoteMark represnets the common quote mark
@ -29,10 +29,12 @@ var (
CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve}
) )
// IsEmpty return true if no prefix and suffix
func (q Quoter) IsEmpty() bool { func (q Quoter) IsEmpty() bool {
return q.Prefix == 0 && q.Suffix == 0 return q.Prefix == 0 && q.Suffix == 0
} }
// Quote quote a string
func (q Quoter) Quote(s string) string { func (q Quoter) Quote(s string) string {
var buf strings.Builder var buf strings.Builder
q.QuoteTo(&buf, s) q.QuoteTo(&buf, s)
@ -59,12 +61,14 @@ func (q Quoter) Trim(s string) string {
return buf.String() return buf.String()
} }
// Join joins a slice with quoters
func (q Quoter) Join(a []string, sep string) string { func (q Quoter) Join(a []string, sep string) string {
var b strings.Builder var b strings.Builder
q.JoinWrite(&b, a, sep) q.JoinWrite(&b, a, sep)
return b.String() return b.String()
} }
// JoinWrite writes quoted content to a builder
func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error {
if len(a) == 0 { if len(a) == 0 {
return nil return nil

View File

@ -90,23 +90,28 @@ func (table *Table) PKColumns() []*Column {
return columns return columns
} }
// ColumnType returns a column's type
func (table *Table) ColumnType(name string) reflect.Type { func (table *Table) ColumnType(name string) reflect.Type {
t, _ := table.Type.FieldByName(name) t, _ := table.Type.FieldByName(name)
return t.Type return t.Type
} }
// AutoIncrColumn returns autoincrement column
func (table *Table) AutoIncrColumn() *Column { func (table *Table) AutoIncrColumn() *Column {
return table.GetColumn(table.AutoIncrement) return table.GetColumn(table.AutoIncrement)
} }
// VersionColumn returns version column's information
func (table *Table) VersionColumn() *Column { func (table *Table) VersionColumn() *Column {
return table.GetColumn(table.Version) return table.GetColumn(table.Version)
} }
// UpdatedColumn returns updated column's information
func (table *Table) UpdatedColumn() *Column { func (table *Table) UpdatedColumn() *Column {
return table.GetColumn(table.Updated) return table.GetColumn(table.Updated)
} }
// DeletedColumn returns deleted column's information
func (table *Table) DeletedColumn() *Column { func (table *Table) DeletedColumn() *Column {
return table.GetColumn(table.Deleted) return table.GetColumn(table.Deleted)
} }

View File

@ -11,8 +11,10 @@ import (
"time" "time"
) )
// DBType represents a database type
type DBType string type DBType string
// enumerates all database types
const ( const (
POSTGRES DBType = "postgres" POSTGRES DBType = "postgres"
SQLITE DBType = "sqlite3" SQLITE DBType = "sqlite3"
@ -28,6 +30,7 @@ type SQLType struct {
DefaultLength2 int DefaultLength2 int
} }
// enumerates all columns types
const ( const (
UNKNOW_TYPE = iota UNKNOW_TYPE = iota
TEXT_TYPE TEXT_TYPE
@ -37,6 +40,7 @@ const (
ARRAY_TYPE ARRAY_TYPE
) )
// IsType reutrns ture if the column type is the same as the parameter
func (s *SQLType) IsType(st int) bool { func (s *SQLType) IsType(st int) bool {
if t, ok := SqlTypes[s.Name]; ok && t == st { if t, ok := SqlTypes[s.Name]; ok && t == st {
return true return true
@ -44,34 +48,42 @@ func (s *SQLType) IsType(st int) bool {
return false return false
} }
// IsText returns true if column is a text type
func (s *SQLType) IsText() bool { func (s *SQLType) IsText() bool {
return s.IsType(TEXT_TYPE) return s.IsType(TEXT_TYPE)
} }
// IsBlob returns true if column is a binary type
func (s *SQLType) IsBlob() bool { func (s *SQLType) IsBlob() bool {
return s.IsType(BLOB_TYPE) return s.IsType(BLOB_TYPE)
} }
// IsTime returns true if column is a time type
func (s *SQLType) IsTime() bool { func (s *SQLType) IsTime() bool {
return s.IsType(TIME_TYPE) return s.IsType(TIME_TYPE)
} }
// IsNumeric returns true if column is a numeric type
func (s *SQLType) IsNumeric() bool { func (s *SQLType) IsNumeric() bool {
return s.IsType(NUMERIC_TYPE) return s.IsType(NUMERIC_TYPE)
} }
// IsArray returns true if column is an array type
func (s *SQLType) IsArray() bool { func (s *SQLType) IsArray() bool {
return s.IsType(ARRAY_TYPE) return s.IsType(ARRAY_TYPE)
} }
// IsJson returns true if column is an array type
func (s *SQLType) IsJson() bool { func (s *SQLType) IsJson() bool {
return s.Name == Json || s.Name == Jsonb return s.Name == Json || s.Name == Jsonb
} }
// IsXML returns true if column is an xml type
func (s *SQLType) IsXML() bool { func (s *SQLType) IsXML() bool {
return s.Name == XML return s.Name == XML
} }
// enumerates all the database column types
var ( var (
Bit = "BIT" Bit = "BIT"
UnsignedBit = "UNSIGNED BIT" UnsignedBit = "UNSIGNED BIT"
@ -210,53 +222,55 @@ var (
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison // !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison
var ( var (
c_EMPTY_STRING string emptyString string
c_BOOL_DEFAULT bool boolDefault bool
c_BYTE_DEFAULT byte byteDefault byte
c_COMPLEX64_DEFAULT complex64 complex64Default complex64
c_COMPLEX128_DEFAULT complex128 complex128Default complex128
c_FLOAT32_DEFAULT float32 float32Default float32
c_FLOAT64_DEFAULT float64 float64Default float64
c_INT64_DEFAULT int64 int64Default int64
c_UINT64_DEFAULT uint64 uint64Default uint64
c_INT32_DEFAULT int32 int32Default int32
c_UINT32_DEFAULT uint32 uint32Default uint32
c_INT16_DEFAULT int16 int16Default int16
c_UINT16_DEFAULT uint16 uint16Default uint16
c_INT8_DEFAULT int8 int8Default int8
c_UINT8_DEFAULT uint8 uint8Default uint8
c_INT_DEFAULT int intDefault int
c_UINT_DEFAULT uint uintDefault uint
c_TIME_DEFAULT time.Time timeDefault time.Time
) )
// enumerates all types
var ( var (
IntType = reflect.TypeOf(c_INT_DEFAULT) IntType = reflect.TypeOf(intDefault)
Int8Type = reflect.TypeOf(c_INT8_DEFAULT) Int8Type = reflect.TypeOf(int8Default)
Int16Type = reflect.TypeOf(c_INT16_DEFAULT) Int16Type = reflect.TypeOf(int16Default)
Int32Type = reflect.TypeOf(c_INT32_DEFAULT) Int32Type = reflect.TypeOf(int32Default)
Int64Type = reflect.TypeOf(c_INT64_DEFAULT) Int64Type = reflect.TypeOf(int64Default)
UintType = reflect.TypeOf(c_UINT_DEFAULT) UintType = reflect.TypeOf(uintDefault)
Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT) Uint8Type = reflect.TypeOf(uint8Default)
Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT) Uint16Type = reflect.TypeOf(uint16Default)
Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT) Uint32Type = reflect.TypeOf(uint32Default)
Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT) Uint64Type = reflect.TypeOf(uint64Default)
Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT) Float32Type = reflect.TypeOf(float32Default)
Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT) Float64Type = reflect.TypeOf(float64Default)
Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT) Complex64Type = reflect.TypeOf(complex64Default)
Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT) Complex128Type = reflect.TypeOf(complex128Default)
StringType = reflect.TypeOf(c_EMPTY_STRING) StringType = reflect.TypeOf(emptyString)
BoolType = reflect.TypeOf(c_BOOL_DEFAULT) BoolType = reflect.TypeOf(boolDefault)
ByteType = reflect.TypeOf(c_BYTE_DEFAULT) ByteType = reflect.TypeOf(byteDefault)
BytesType = reflect.SliceOf(ByteType) BytesType = reflect.SliceOf(ByteType)
TimeType = reflect.TypeOf(c_TIME_DEFAULT) TimeType = reflect.TypeOf(timeDefault)
) )
// enumerates all types
var ( var (
PtrIntType = reflect.PtrTo(IntType) PtrIntType = reflect.PtrTo(IntType)
PtrInt8Type = reflect.PtrTo(Int8Type) PtrInt8Type = reflect.PtrTo(Int8Type)
@ -301,7 +315,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
case reflect.Complex64, reflect.Complex128: case reflect.Complex64, reflect.Complex128:
st = SQLType{Varchar, 64, 0} st = SQLType{Varchar, 64, 0}
case reflect.Array, reflect.Slice, reflect.Map: case reflect.Array, reflect.Slice, reflect.Map:
if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) { if t.Elem() == reflect.TypeOf(byteDefault) {
st = SQLType{Blob, 0, 0} st = SQLType{Blob, 0, 0}
} else { } else {
st = SQLType{Text, 0, 0} st = SQLType{Text, 0, 0}
@ -325,7 +339,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) {
return return
} }
// default sql type change to go types // SQLType2Type convert default sql type change to go types
func SQLType2Type(st SQLType) reflect.Type { func SQLType2Type(st SQLType) reflect.Type {
name := strings.ToUpper(st.Name) name := strings.ToUpper(st.Name)
switch name { switch name {
@ -344,7 +358,7 @@ func SQLType2Type(st SQLType) reflect.Type {
case Bool: case Bool:
return reflect.TypeOf(true) return reflect.TypeOf(true)
case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year:
return reflect.TypeOf(c_TIME_DEFAULT) return reflect.TypeOf(timeDefault)
case Decimal, Numeric, Money, SmallMoney: case Decimal, Numeric, Money, SmallMoney:
return reflect.TypeOf("") return reflect.TypeOf("")
default: default:

View File

@ -21,9 +21,11 @@ import (
) )
var ( var (
// ErrUnsupportedType represents an unsupported type error
ErrUnsupportedType = errors.New("Unsupported type") ErrUnsupportedType = errors.New("Unsupported type")
) )
// Parser represents a parser for xorm tag
type Parser struct { type Parser struct {
identifier string identifier string
dialect dialects.Dialect dialect dialects.Dialect
@ -34,6 +36,7 @@ type Parser struct {
tableCache sync.Map // map[reflect.Type]*schemas.Table tableCache sync.Map // map[reflect.Type]*schemas.Table
} }
// NewParser creates a tag parser
func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser {
return &Parser{ return &Parser{
identifier: identifier, identifier: identifier,
@ -45,29 +48,35 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM
} }
} }
// GetTableMapper returns table mapper
func (parser *Parser) GetTableMapper() names.Mapper { func (parser *Parser) GetTableMapper() names.Mapper {
return parser.tableMapper return parser.tableMapper
} }
// SetTableMapper sets table mapper
func (parser *Parser) SetTableMapper(mapper names.Mapper) { func (parser *Parser) SetTableMapper(mapper names.Mapper) {
parser.ClearCaches() parser.ClearCaches()
parser.tableMapper = mapper parser.tableMapper = mapper
} }
// GetColumnMapper returns column mapper
func (parser *Parser) GetColumnMapper() names.Mapper { func (parser *Parser) GetColumnMapper() names.Mapper {
return parser.columnMapper return parser.columnMapper
} }
// SetColumnMapper sets column mapper
func (parser *Parser) SetColumnMapper(mapper names.Mapper) { func (parser *Parser) SetColumnMapper(mapper names.Mapper) {
parser.ClearCaches() parser.ClearCaches()
parser.columnMapper = mapper parser.columnMapper = mapper
} }
// SetIdentifier sets tag identifier
func (parser *Parser) SetIdentifier(identifier string) { func (parser *Parser) SetIdentifier(identifier string) {
parser.ClearCaches() parser.ClearCaches()
parser.identifier = identifier parser.identifier = identifier
} }
// ParseWithCache parse a struct with cache
func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) {
t := v.Type() t := v.Type()
tableI, ok := parser.tableCache.Load(t) tableI, ok := parser.tableCache.Load(t)