diff --git a/.revive.toml b/.revive.toml index 64e223bb..6dec7465 100644 --- a/.revive.toml +++ b/.revive.toml @@ -15,6 +15,7 @@ warningCode = 1 [rule.if-return] [rule.increment-decrement] [rule.var-naming] + arguments = [["ID", "UID", "UUID", "URL", "JSON"], []] [rule.var-declaration] [rule.package-comments] [rule.range] @@ -22,4 +23,5 @@ warningCode = 1 [rule.time-naming] [rule.unexported-return] [rule.indent-error-flow] -[rule.errorf] \ No newline at end of file +[rule.errorf] +[rule.struct-tag] \ No newline at end of file diff --git a/caches/encode.go b/caches/encode.go index 4ba39924..95536d7e 100644 --- a/caches/encode.go +++ b/caches/encode.go @@ -13,22 +13,26 @@ import ( "io" ) -// md5 hash string +// Md5 return md5 hash string func Md5(str string) string { m := md5.New() io.WriteString(m, str) return fmt.Sprintf("%x", m.Sum(nil)) } + +// Encode Encode data func Encode(data interface{}) ([]byte, error) { //return JsonEncode(data) return GobEncode(data) } +// Decode decode data func Decode(data []byte, to interface{}) error { //return JsonDecode(data, to) return GobDecode(data, to) } +// GobEncode encode data with gob func GobEncode(data interface{}) ([]byte, error) { var buf bytes.Buffer enc := gob.NewEncoder(&buf) @@ -39,12 +43,14 @@ func GobEncode(data interface{}) ([]byte, error) { return buf.Bytes(), nil } +// GobDecode decode data with gob func GobDecode(data []byte, to interface{}) error { buf := bytes.NewBuffer(data) dec := gob.NewDecoder(buf) return dec.Decode(to) } +// JsonEncode encode data with json func JsonEncode(data interface{}) ([]byte, error) { val, err := json.Marshal(data) if err != nil { @@ -53,6 +59,7 @@ func JsonEncode(data interface{}) ([]byte, error) { return val, nil } +// JsonDecode decode data with json func JsonDecode(data []byte, to interface{}) error { return json.Unmarshal(data, to) } diff --git a/core/db.go b/core/db.go index 50c64c6f..ef5ab227 100644 --- a/core/db.go +++ b/core/db.go @@ -23,6 +23,7 @@ var ( DefaultCacheSize = 200 ) +// MapToSlice map query and struct as sql and args func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -44,6 +45,7 @@ func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { return query, args, err } +// StructToSlice converts a query and struct as sql and args func StructToSlice(query string, st interface{}) (string, []interface{}, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -176,6 +178,7 @@ func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { return db.QueryMapContext(context.Background(), query, mp) } +// QueryStructContext query rows with struct func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -184,10 +187,12 @@ func (db *DB) QueryStructContext(ctx context.Context, query string, st interface return db.QueryContext(ctx, query, args...) } +// QueryStruct query rows with struct func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { return db.QueryStructContext(context.Background(), query, st) } +// QueryRowContext query row with args func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := db.QueryContext(ctx, query, args...) if err != nil { @@ -196,10 +201,12 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interfa return &Row{rows, nil} } +// QueryRow query row with args func (db *DB) QueryRow(query string, args ...interface{}) *Row { return db.QueryRowContext(context.Background(), query, args...) } +// QueryRowMapContext query row with map func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { @@ -208,10 +215,12 @@ func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface return db.QueryRowContext(ctx, query, args...) } +// QueryRowMap query row with map func (db *DB) QueryRowMap(query string, mp interface{}) *Row { return db.QueryRowMapContext(context.Background(), query, mp) } +// QueryRowStructContext query row with struct func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { @@ -220,6 +229,7 @@ func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interf return db.QueryRowContext(ctx, query, args...) } +// QueryRowStruct query row with struct func (db *DB) QueryRowStruct(query string, st interface{}) *Row { return db.QueryRowStructContext(context.Background(), query, st) } @@ -239,10 +249,12 @@ func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) return db.ExecContext(ctx, query, args...) } +// ExecMap exec query with map func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { return db.ExecMapContext(context.Background(), query, mp) } +// ExecStructContext exec query with map func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -251,6 +263,7 @@ func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{ return db.ExecContext(ctx, query, args...) } +// ExecContext exec query with args func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := db.beforeProcess(hookCtx) @@ -265,6 +278,7 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{} return res, nil } +// ExecStruct exec query with struct func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { return db.ExecStructContext(context.Background(), query, st) } @@ -288,6 +302,7 @@ func (db *DB) afterProcess(c *contexts.ContextHook) error { return err } +// AddHook adds hook func (db *DB) AddHook(h ...contexts.Hook) { db.hooks.AddHook(h...) } diff --git a/core/db_test.go b/core/db_test.go index 104c5b95..e9c2d82d 100644 --- a/core/db_test.go +++ b/core/db_test.go @@ -21,7 +21,7 @@ import ( var ( dbtype = flag.String("dbtype", "sqlite3", "database type") dbConn = flag.String("dbConn", "./db_test.db", "database connect string") - createTableSql string + createTableSQL string ) func TestMain(m *testing.M) { @@ -29,12 +29,12 @@ func TestMain(m *testing.M) { switch *dbtype { case "sqlite3", "sqlite": - createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" case "mysql": fallthrough default: - createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + + createTableSQL = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" } @@ -66,7 +66,7 @@ func BenchmarkOriQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -121,7 +121,7 @@ func BenchmarkStructQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -166,7 +166,7 @@ func BenchmarkStruct2Query(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -212,7 +212,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -270,7 +270,7 @@ func BenchmarkSliceInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -321,7 +321,7 @@ func BenchmarkSliceStringQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -372,7 +372,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -426,7 +426,7 @@ func BenchmarkMapInterfaceQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -473,7 +473,7 @@ func BenchmarkMapStringQuery(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -519,7 +519,7 @@ func BenchmarkExec(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -544,7 +544,7 @@ func BenchmarkExecMap(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } @@ -577,7 +577,7 @@ func TestExecMap(t *testing.T) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { t.Error(err) } @@ -620,7 +620,7 @@ func TestExecStruct(t *testing.T) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { t.Error(err) } @@ -663,7 +663,7 @@ func BenchmarkExecStruct(b *testing.B) { } defer db.Close() - _, err = db.Exec(createTableSql) + _, err = db.Exec(createTableSQL) if err != nil { b.Error(err) } diff --git a/core/rows.go b/core/rows.go index a1e8bfbc..c15a59a3 100644 --- a/core/rows.go +++ b/core/rows.go @@ -11,11 +11,13 @@ import ( "sync" ) +// Rows represents rows of table type Rows struct { *sql.Rows db *DB } +// ToMapString returns all records func (rs *Rows) ToMapString() ([]map[string]string, error) { cols, err := rs.Columns() if err != nil { @@ -34,7 +36,7 @@ func (rs *Rows) ToMapString() ([]map[string]string, error) { return results, nil } -// scan data to a struct's pointer according field index +// ScanStructByIndex scan data to a struct's pointer according field index func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { if len(dest) == 0 { return errors.New("at least one struct") @@ -94,7 +96,7 @@ func fieldByName(v reflect.Value, name string) reflect.Value { return reflect.Zero(t) } -// scan data to a struct's pointer according field name +// ScanStructByName scan data to a struct's pointer according field name func (rs *Rows) ScanStructByName(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -120,7 +122,7 @@ func (rs *Rows) ScanStructByName(dest interface{}) error { return rs.Rows.Scan(newDest...) } -// scan data to a slice's pointer, slice's length should equal to columns' number +// ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number func (rs *Rows) ScanSlice(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { @@ -155,7 +157,7 @@ func (rs *Rows) ScanSlice(dest interface{}) error { return nil } -// scan data to a map's pointer +// ScanMap scan data to a map's pointer func (rs *Rows) ScanMap(dest interface{}) error { vv := reflect.ValueOf(dest) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -187,6 +189,7 @@ func (rs *Rows) ScanMap(dest interface{}) error { return nil } +// Row reprents a row of a tab type Row struct { rows *Rows // One of these two will be non-nil: @@ -205,6 +208,7 @@ func NewRow(rows *Rows, err error) *Row { return &Row{rows, err} } +// Columns returns all columns of the row func (row *Row) Columns() ([]string, error) { if row.err != nil { return nil, row.err @@ -212,6 +216,7 @@ func (row *Row) Columns() ([]string, error) { return row.rows.Columns() } +// Scan retrieves all row column values func (row *Row) Scan(dest ...interface{}) error { if row.err != nil { return row.err @@ -238,6 +243,7 @@ func (row *Row) Scan(dest ...interface{}) error { return row.rows.Close() } +// ScanStructByName retrieves all row column values into a struct func (row *Row) ScanStructByName(dest interface{}) error { if row.err != nil { return row.err @@ -258,6 +264,7 @@ func (row *Row) ScanStructByName(dest interface{}) error { return row.rows.Close() } +// ScanStructByIndex retrieves all row column values into a struct func (row *Row) ScanStructByIndex(dest interface{}) error { if row.err != nil { return row.err @@ -278,7 +285,7 @@ func (row *Row) ScanStructByIndex(dest interface{}) error { return row.rows.Close() } -// scan data to a slice's pointer, slice's length should equal to columns' number +// ScanSlice scan data to a slice's pointer, slice's length should equal to columns' number func (row *Row) ScanSlice(dest interface{}) error { if row.err != nil { return row.err @@ -300,7 +307,7 @@ func (row *Row) ScanSlice(dest interface{}) error { return row.rows.Close() } -// scan data to a map's pointer +// ScanMap scan data to a map's pointer func (row *Row) ScanMap(dest interface{}) error { if row.err != nil { return row.err @@ -322,6 +329,7 @@ func (row *Row) ScanMap(dest interface{}) error { return row.rows.Close() } +// ToMapString returns all clumns of this record func (row *Row) ToMapString() (map[string]string, error) { cols, err := row.Columns() if err != nil { diff --git a/core/scan.go b/core/scan.go index 897b5341..1e7e4525 100644 --- a/core/scan.go +++ b/core/scan.go @@ -10,12 +10,14 @@ import ( "time" ) +// NullTime defines a customize type NullTime type NullTime time.Time var ( _ driver.Valuer = NullTime{} ) +// Scan implements driver.Valuer func (ns *NullTime) Scan(value interface{}) error { if value == nil { return nil @@ -58,9 +60,11 @@ func convertTime(dest *NullTime, src interface{}) error { return nil } +// EmptyScanner represents an empty scanner type EmptyScanner struct { } +// Scan implements func (EmptyScanner) Scan(src interface{}) error { return nil } diff --git a/core/stmt.go b/core/stmt.go index d46ac9c6..260843d5 100644 --- a/core/stmt.go +++ b/core/stmt.go @@ -21,6 +21,7 @@ type Stmt struct { query string } +// PrepareContext creates a prepare statement func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { names := make(map[string]int) var i int @@ -42,10 +43,12 @@ func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return &Stmt{stmt, db, names, query}, nil } +// Prepare creates a prepare statement func (db *DB) Prepare(query string) (*Stmt, error) { return db.PrepareContext(context.Background(), query) } +// ExecMapContext execute with map func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -59,10 +62,12 @@ func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, return s.ExecContext(ctx, args...) } +// ExecMap executes with map func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { return s.ExecMapContext(context.Background(), mp) } +// ExecStructContext executes with struct func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -76,10 +81,12 @@ func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Resul return s.ExecContext(ctx, args...) } +// ExecStruct executes with struct func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { return s.ExecStructContext(context.Background(), st) } +// ExecContext with args func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, s.query, args) ctx, err := s.db.beforeProcess(hookCtx) @@ -94,6 +101,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result return res, nil } +// QueryContext query with args func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { hookCtx := contexts.NewContextHook(ctx, s.query, args) ctx, err := s.db.beforeProcess(hookCtx) @@ -108,10 +116,12 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return &Rows{rows, s.db}, nil } +// Query query with args func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } +// QueryMapContext query with map func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -126,10 +136,12 @@ func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, erro return s.QueryContext(ctx, args...) } +// QueryMap query with map func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { return s.QueryMapContext(context.Background(), mp) } +// QueryStructContext query with struct func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -144,19 +156,23 @@ func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, e return s.QueryContext(ctx, args...) } +// QueryStruct query with struct func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { return s.QueryStructContext(context.Background(), st) } +// QueryRowContext query row with args func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { rows, err := s.QueryContext(ctx, args...) return &Row{rows, err} } +// QueryRow query row with args func (s *Stmt) QueryRow(args ...interface{}) *Row { return s.QueryRowContext(context.Background(), args...) } +// QueryRowMapContext query row with map func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { @@ -171,10 +187,12 @@ func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { return s.QueryRowContext(ctx, args...) } +// QueryRowMap query row with map func (s *Stmt) QueryRowMap(mp interface{}) *Row { return s.QueryRowMapContext(context.Background(), mp) } +// QueryRowStructContext query row with struct func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { @@ -189,6 +207,7 @@ func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { return s.QueryRowContext(ctx, args...) } +// QueryRowStruct query row with struct func (s *Stmt) QueryRowStruct(st interface{}) *Row { return s.QueryRowStructContext(context.Background(), st) } diff --git a/core/tx.go b/core/tx.go index a85a6874..a2f745f8 100644 --- a/core/tx.go +++ b/core/tx.go @@ -22,6 +22,7 @@ type Tx struct { ctx context.Context } +// BeginTx begin a transaction with option func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { hookCtx := contexts.NewContextHook(ctx, "BEGIN TRANSACTION", nil) ctx, err := db.beforeProcess(hookCtx) @@ -36,10 +37,12 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { return &Tx{tx, db, ctx}, nil } +// Begin begins a transaction func (db *DB) Begin() (*Tx, error) { return db.BeginTx(context.Background(), nil) } +// Commit submit the transaction func (tx *Tx) Commit() error { hookCtx := contexts.NewContextHook(tx.ctx, "COMMIT", nil) ctx, err := tx.db.beforeProcess(hookCtx) @@ -48,12 +51,10 @@ func (tx *Tx) Commit() error { } err = tx.Tx.Commit() hookCtx.End(ctx, nil, err) - if err := tx.db.afterProcess(hookCtx); err != nil { - return err - } - return nil + return tx.db.afterProcess(hookCtx) } +// Rollback rollback the transaction func (tx *Tx) Rollback() error { hookCtx := contexts.NewContextHook(tx.ctx, "ROLLBACK", nil) ctx, err := tx.db.beforeProcess(hookCtx) @@ -62,12 +63,10 @@ func (tx *Tx) Rollback() error { } err = tx.Tx.Rollback() hookCtx.End(ctx, nil, err) - if err := tx.db.afterProcess(hookCtx); err != nil { - return err - } - return nil + return tx.db.afterProcess(hookCtx) } +// PrepareContext prepare the query func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { names := make(map[string]int) var i int @@ -89,19 +88,23 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { return &Stmt{stmt, tx.db, names, query}, nil } +// Prepare prepare the query func (tx *Tx) Prepare(query string) (*Stmt, error) { return tx.PrepareContext(context.Background(), query) } +// StmtContext creates Stmt with context func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt) return stmt } +// Stmt creates Stmt func (tx *Tx) Stmt(stmt *Stmt) *Stmt { return tx.StmtContext(context.Background(), stmt) } +// ExecMapContext executes query with args in a map func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { query, args, err := MapToSlice(query, mp) if err != nil { @@ -110,10 +113,12 @@ func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) return tx.ExecContext(ctx, query, args...) } +// ExecMap executes query with args in a map func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { return tx.ExecMapContext(context.Background(), query, mp) } +// ExecStructContext executes query with args in a struct func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -122,6 +127,7 @@ func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{ return tx.ExecContext(ctx, query, args...) } +// ExecContext executes a query with args func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := tx.db.beforeProcess(hookCtx) @@ -136,10 +142,12 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{} return res, err } +// ExecStruct executes query with args in a struct func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { return tx.ExecStructContext(context.Background(), query, st) } +// QueryContext query with args func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { hookCtx := contexts.NewContextHook(ctx, query, args) ctx, err := tx.db.beforeProcess(hookCtx) @@ -157,10 +165,12 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{ return &Rows{rows, tx.db}, nil } +// Query query with args func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { return tx.QueryContext(context.Background(), query, args...) } +// QueryMapContext query with args in a map func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { query, args, err := MapToSlice(query, mp) if err != nil { @@ -169,10 +179,12 @@ func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) return tx.QueryContext(ctx, query, args...) } +// QueryMap query with args in a map func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { return tx.QueryMapContext(context.Background(), query, mp) } +// QueryStructContext query with args in struct func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { query, args, err := StructToSlice(query, st) if err != nil { @@ -181,19 +193,23 @@ func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface return tx.QueryContext(ctx, query, args...) } +// QueryStruct query with args in struct func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { return tx.QueryStructContext(context.Background(), query, st) } +// QueryRowContext query one row with args func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { rows, err := tx.QueryContext(ctx, query, args...) return &Row{rows, err} } +// QueryRow query one row with args func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { return tx.QueryRowContext(context.Background(), query, args...) } +// QueryRowMapContext query one row with args in a map func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { @@ -202,10 +218,12 @@ func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface return tx.QueryRowContext(ctx, query, args...) } +// QueryRowMap query one row with args in a map func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { return tx.QueryRowMapContext(context.Background(), query, mp) } +// QueryRowStructContext query one row with args in struct func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { @@ -214,6 +232,7 @@ func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interf return tx.QueryRowContext(ctx, query, args...) } +// QueryRowStruct query one row with args in struct func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { return tx.QueryRowStructContext(context.Background(), query, st) } diff --git a/dialects/driver.go b/dialects/driver.go index ae3afe42..bb46a936 100644 --- a/dialects/driver.go +++ b/dialects/driver.go @@ -8,6 +8,7 @@ import ( "fmt" ) +// Driver represents a database driver type Driver interface { Parse(string, string) (*URI, error) } @@ -16,6 +17,7 @@ var ( drivers = map[string]Driver{} ) +// RegisterDriver register a driver func RegisterDriver(driverName string, driver Driver) { if driver == nil { panic("core: Register driver is nil") @@ -26,10 +28,12 @@ func RegisterDriver(driverName string, driver Driver) { drivers[driverName] = driver } +// QueryDriver query a driver with name func QueryDriver(driverName string) Driver { return drivers[driverName] } +// RegisteredDriverSize returned all drivers's length func RegisteredDriverSize() int { return len(drivers) } @@ -38,7 +42,7 @@ func RegisteredDriverSize() int { func OpenDialect(driverName, connstr string) (Dialect, error) { driver := QueryDriver(driverName) if driver == nil { - return nil, fmt.Errorf("Unsupported driver name: %v", driverName) + return nil, fmt.Errorf("unsupported driver name: %v", driverName) } uri, err := driver.Parse(driverName, connstr) @@ -48,7 +52,7 @@ func OpenDialect(driverName, connstr string) (Dialect, error) { dialect := QueryDialect(uri.DBType) if dialect == nil { - return nil, fmt.Errorf("Unsupported dialect type: %v", uri.DBType) + return nil, fmt.Errorf("unsupported dialect type: %v", uri.DBType) } dialect.Init(uri) diff --git a/dialects/filter.go b/dialects/filter.go index 6968b6ce..2a36a731 100644 --- a/dialects/filter.go +++ b/dialects/filter.go @@ -38,6 +38,7 @@ func convertQuestionMark(sql, prefix string, start int) string { return buf.String() } +// Do implements Filter func (s *SeqFilter) Do(sql string) string { return convertQuestionMark(sql, s.Prefix, s.Start) } diff --git a/dialects/time.go b/dialects/time.go index 9a3c82a4..5aee0c10 100644 --- a/dialects/time.go +++ b/dialects/time.go @@ -38,6 +38,7 @@ func FormatTime(dialect Dialect, sqlTypeName string, t time.Time) (v interface{} return } +// FormatColumnTime format column time func FormatColumnTime(dialect Dialect, defaultTimeZone *time.Location, col *schemas.Column, t time.Time) (v interface{}) { if t.IsZero() { if col.Nullable { diff --git a/integrations/session_delete_test.go b/integrations/session_delete_test.go index f3565963..cc7e861d 100644 --- a/integrations/session_delete_test.go +++ b/integrations/session_delete_test.go @@ -97,6 +97,7 @@ func TestDeleted(t *testing.T) { // Test normal Find() var records1 []Deleted err = testEngine.Where("`"+testEngine.GetColumnMapper().Obj2Table("Id")+"` > 0").Find(&records1, &Deleted{}) + assert.NoError(t, err) assert.EqualValues(t, 3, len(records1)) // Test normal Get() @@ -132,6 +133,7 @@ func TestDeleted(t *testing.T) { record2 := &Deleted{} has, err = testEngine.ID(2).Get(record2) assert.NoError(t, err) + assert.True(t, has) assert.True(t, record2.DeletedAt.IsZero()) // Test find all records whatever `deleted`. diff --git a/integrations/tests.go b/integrations/tests.go index 512f3962..8b14b0f4 100644 --- a/integrations/tests.go +++ b/integrations/tests.go @@ -166,10 +166,7 @@ func createEngine(dbType, connStr string) error { for _, table := range tables { tableNames = append(tableNames, table.Name) } - if err = testEngine.DropTables(tableNames...); err != nil { - return err - } - return nil + return testEngine.DropTables(tableNames...) } // PrepareEngine prepare tests ORM engine diff --git a/internal/statements/cache.go b/internal/statements/cache.go index cb33df08..669cd018 100644 --- a/internal/statements/cache.go +++ b/internal/statements/cache.go @@ -12,6 +12,7 @@ import ( "xorm.io/xorm/schemas" ) +// ConvertIDSQL converts SQL with id func (statement *Statement) ConvertIDSQL(sqlStr string) string { if statement.RefTable != nil { cols := statement.RefTable.PKColumns() @@ -37,6 +38,7 @@ func (statement *Statement) ConvertIDSQL(sqlStr string) string { return "" } +// ConvertUpdateSQL converts update SQL func (statement *Statement) ConvertUpdateSQL(sqlStr string) (string, string) { if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 { return "", "" diff --git a/internal/statements/expr_param.go b/internal/statements/expr_param.go index 6657408e..d0c355d3 100644 --- a/internal/statements/expr_param.go +++ b/internal/statements/expr_param.go @@ -12,6 +12,7 @@ import ( "xorm.io/xorm/schemas" ) +// ErrUnsupportedExprType represents an error with unsupported express type type ErrUnsupportedExprType struct { tp string } diff --git a/internal/statements/query.go b/internal/statements/query.go index ab3021bf..f1b36770 100644 --- a/internal/statements/query.go +++ b/internal/statements/query.go @@ -14,6 +14,7 @@ import ( "xorm.io/xorm/schemas" ) +// GenQuerySQL generate query SQL func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []interface{}, error) { if len(sqlOrArgs) > 0 { return statement.ConvertSQLOrArgs(sqlOrArgs...) @@ -72,6 +73,7 @@ func (statement *Statement) GenQuerySQL(sqlOrArgs ...interface{}) (string, []int return sqlStr, args, nil } +// GenSumSQL generates sum SQL func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil @@ -102,6 +104,7 @@ func (statement *Statement) GenSumSQL(bean interface{}, columns ...string) (stri return sqlStr, append(statement.joinArgs, condArgs...), nil } +// GenGetSQL generates Get SQL func (statement *Statement) GenGetSQL(bean interface{}) (string, []interface{}, error) { v := rValue(bean) isStruct := v.Kind() == reflect.Struct @@ -316,6 +319,7 @@ func (statement *Statement) genSelectSQL(columnStr string, needLimit, needOrderB return buf.String(), condArgs, nil } +// GenExistSQL generates Exist SQL func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil @@ -385,6 +389,7 @@ func (statement *Statement) GenExistSQL(bean ...interface{}) (string, []interfac return sqlStr, args, nil } +// GenFindSQL generates Find SQL func (statement *Statement) GenFindSQL(autoCond builder.Cond) (string, []interface{}, error) { if statement.RawSQL != "" { return statement.GenRawSQL(), statement.RawParams, nil diff --git a/internal/statements/statement.go b/internal/statements/statement.go index a4294bec..3dd036a6 100644 --- a/internal/statements/statement.go +++ b/internal/statements/statement.go @@ -90,19 +90,17 @@ func NewStatement(dialect dialects.Dialect, tagParser *tags.Parser, defaultTimeZ return statement } +// SetTableName set table name func (statement *Statement) SetTableName(tableName string) { statement.tableName = tableName } -func (statement *Statement) omitStr() string { - return statement.dialect.Quoter().Join(statement.OmitColumnMap, " ,") -} - // GenRawSQL generates correct raw sql func (statement *Statement) GenRawSQL() string { return statement.ReplaceQuote(statement.RawSQL) } +// GenCondSQL generates condition SQL func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []interface{}, error) { condSQL, condArgs, err := builder.ToSQL(condOrBuilder) if err != nil { @@ -111,6 +109,7 @@ func (statement *Statement) GenCondSQL(condOrBuilder interface{}) (string, []int return statement.ReplaceQuote(condSQL), condArgs, nil } +// ReplaceQuote replace sql key words with quote func (statement *Statement) ReplaceQuote(sql string) string { if sql == "" || statement.dialect.URI().DBType == schemas.MYSQL || statement.dialect.URI().DBType == schemas.SQLITE { @@ -119,11 +118,12 @@ func (statement *Statement) ReplaceQuote(sql string) string { return statement.dialect.Quoter().Replace(sql) } +// SetContextCache sets context cache func (statement *Statement) SetContextCache(ctxCache contexts.ContextCache) { statement.Context = ctxCache } -// Init reset all the statement's fields +// Reset reset all the statement's fields func (statement *Statement) Reset() { statement.RefTable = nil statement.Start = 0 @@ -163,7 +163,7 @@ func (statement *Statement) Reset() { statement.LastError = nil } -// NoAutoCondition if you do not want convert bean's field as query condition, then use this function +// SetNoAutoCondition if you do not want convert bean's field as query condition, then use this function func (statement *Statement) SetNoAutoCondition(no ...bool) *Statement { statement.NoAutoCondition = true if len(no) > 0 { @@ -271,6 +271,7 @@ func (statement *Statement) NotIn(column string, args ...interface{}) *Statement return statement } +// SetRefValue set ref value func (statement *Statement) SetRefValue(v reflect.Value) error { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(reflect.Indirect(v)) @@ -285,6 +286,7 @@ func rValue(bean interface{}) reflect.Value { return reflect.Indirect(reflect.ValueOf(bean)) } +// SetRefBean set ref bean func (statement *Statement) SetRefBean(bean interface{}) error { var err error statement.RefTable, err = statement.tagParser.ParseWithCache(rValue(bean)) @@ -390,6 +392,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { return statement } +// ColumnStr returns column string func (statement *Statement) ColumnStr() string { return statement.dialect.Quoter().Join(statement.ColumnMap, ", ") } @@ -493,11 +496,12 @@ func (statement *Statement) Asc(colNames ...string) *Statement { return statement } +// Conds returns condtions func (statement *Statement) Conds() builder.Cond { return statement.cond } -// Table tempororily set table name, the parameter could be a string or a pointer of struct +// SetTable tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) SetTable(tableNameOrBean interface{}) error { v := rValue(tableNameOrBean) t := v.Type() @@ -564,7 +568,7 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition return statement } -// tbName get some table's table name +// tbNameNoSchema get some table's table name func (statement *Statement) tbNameNoSchema(table *schemas.Table) string { if len(statement.AltTableName) > 0 { return statement.AltTableName @@ -585,12 +589,13 @@ func (statement *Statement) Having(conditions string) *Statement { return statement } -// Unscoped always disable struct tag "deleted" +// SetUnscoped always disable struct tag "deleted" func (statement *Statement) SetUnscoped() *Statement { statement.unscoped = true return statement } +// GetUnscoped return true if it's unscoped func (statement *Statement) GetUnscoped() bool { return statement.unscoped } @@ -636,6 +641,7 @@ func (statement *Statement) genColumnStr() string { return buf.String() } +// GenCreateTableSQL generated create table SQL func (statement *Statement) GenCreateTableSQL() []string { statement.RefTable.StoreEngine = statement.StoreEngine statement.RefTable.Charset = statement.Charset @@ -643,6 +649,7 @@ func (statement *Statement) GenCreateTableSQL() []string { return s } +// GenIndexSQL generated create index SQL func (statement *Statement) GenIndexSQL() []string { var sqls []string tbName := statement.TableName() @@ -659,6 +666,7 @@ func uniqueName(tableName, uqeName string) string { return fmt.Sprintf("UQE_%v_%v", tableName, uqeName) } +// GenUniqueSQL generates unique SQL func (statement *Statement) GenUniqueSQL() []string { var sqls []string tbName := statement.TableName() @@ -671,6 +679,7 @@ func (statement *Statement) GenUniqueSQL() []string { return sqls } +// GenDelIndexSQL generate delete index SQL func (statement *Statement) GenDelIndexSQL() []string { var sqls []string tbName := statement.TableName() @@ -896,6 +905,7 @@ func (statement *Statement) buildConds2(table *schemas.Table, bean interface{}, return builder.And(conds...), nil } +// BuildConds builds condition func (statement *Statement) BuildConds(table *schemas.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) { return statement.buildConds2(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, statement.unscoped, statement.MustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) @@ -911,12 +921,10 @@ func (statement *Statement) mergeConds(bean interface{}) error { statement.cond = statement.cond.And(autoCond) } - if err := statement.ProcessIDParam(); err != nil { - return err - } - return nil + return statement.ProcessIDParam() } +// GenConds generates conditions func (statement *Statement) GenConds(bean interface{}) (string, []interface{}, error) { if err := statement.mergeConds(bean); err != nil { return "", nil, err @@ -930,6 +938,7 @@ func (statement *Statement) quoteColumnStr(columnStr string) string { return statement.dialect.Quoter().Join(columns, ",") } +// ConvertSQLOrArgs converts sql or args func (statement *Statement) ConvertSQLOrArgs(sqlOrArgs ...interface{}) (string, []interface{}, error) { sql, args, err := convertSQLOrArgs(sqlOrArgs...) if err != nil { diff --git a/internal/statements/statement_args.go b/internal/statements/statement_args.go index dc14467d..64089c1e 100644 --- a/internal/statements/statement_args.go +++ b/internal/statements/statement_args.go @@ -77,6 +77,7 @@ func convertArg(arg interface{}, convertFunc func(string) string) string { const insertSelectPlaceHolder = true +// WriteArg writes an arg func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) error { switch argv := arg.(type) { case *builder.Builder: @@ -116,6 +117,7 @@ func (statement *Statement) WriteArg(w *builder.BytesWriter, arg interface{}) er return nil } +// WriteArgs writes args func (statement *Statement) WriteArgs(w *builder.BytesWriter, args []interface{}) error { for i, arg := range args { if err := statement.WriteArg(w, arg); err != nil { diff --git a/names/mapper.go b/names/mapper.go index 79add76e..b0ce8076 100644 --- a/names/mapper.go +++ b/names/mapper.go @@ -16,6 +16,7 @@ type Mapper interface { Table2Obj(string) string } +// CacheMapper represents a cache mapper type CacheMapper struct { oriMapper Mapper obj2tableCache map[string]string @@ -24,12 +25,14 @@ type CacheMapper struct { table2objMutex sync.RWMutex } +// NewCacheMapper creates a cache mapper func NewCacheMapper(mapper Mapper) *CacheMapper { return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string), table2objCache: make(map[string]string), } } +// Obj2Table implements Mapper func (m *CacheMapper) Obj2Table(o string) string { m.obj2tableMutex.RLock() t, ok := m.obj2tableCache[o] @@ -45,6 +48,7 @@ func (m *CacheMapper) Obj2Table(o string) string { return t } +// Table2Obj implements Mapper func (m *CacheMapper) Table2Obj(t string) string { m.table2objMutex.RLock() o, ok := m.table2objCache[t] @@ -60,15 +64,17 @@ func (m *CacheMapper) Table2Obj(t string) string { return o } -// SameMapper implements IMapper and provides same name between struct and +// SameMapper implements Mapper and provides same name between struct and // database table type SameMapper struct { } +// Obj2Table implements Mapper func (m SameMapper) Obj2Table(o string) string { return o } +// Table2Obj implements Mapper func (m SameMapper) Table2Obj(t string) string { return t } @@ -98,6 +104,7 @@ func snakeCasedName(name string) string { return b2s(newstr) } +// Obj2Table implements Mapper func (mapper SnakeMapper) Obj2Table(name string) string { return snakeCasedName(name) } @@ -127,6 +134,7 @@ func titleCasedName(name string) string { return b2s(newstr) } +// Table2Obj implements Mapper func (mapper SnakeMapper) Table2Obj(name string) string { return titleCasedName(name) } @@ -168,10 +176,12 @@ func gonicCasedName(name string) string { return strings.ToLower(string(newstr)) } +// Obj2Table implements Mapper func (mapper GonicMapper) Obj2Table(name string) string { return gonicCasedName(name) } +// Table2Obj implements Mapper func (mapper GonicMapper) Table2Obj(name string) string { newstr := make([]rune, 0) @@ -234,14 +244,17 @@ type PrefixMapper struct { Prefix string } +// Obj2Table implements Mapper func (mapper PrefixMapper) Obj2Table(name string) string { return mapper.Prefix + mapper.Mapper.Obj2Table(name) } +// Table2Obj implements Mapper func (mapper PrefixMapper) Table2Obj(name string) string { return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) } +// NewPrefixMapper creates a prefix mapper func NewPrefixMapper(mapper Mapper, prefix string) PrefixMapper { return PrefixMapper{mapper, prefix} } @@ -252,14 +265,17 @@ type SuffixMapper struct { Suffix string } +// Obj2Table implements Mapper func (mapper SuffixMapper) Obj2Table(name string) string { return mapper.Mapper.Obj2Table(name) + mapper.Suffix } +// Table2Obj implements Mapper func (mapper SuffixMapper) Table2Obj(name string) string { return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) } +// NewSuffixMapper creates a suffix mapper func NewSuffixMapper(mapper Mapper, suffix string) SuffixMapper { return SuffixMapper{mapper, suffix} } diff --git a/names/table_name.go b/names/table_name.go index 0afb1ae3..cc0e9274 100644 --- a/names/table_name.go +++ b/names/table_name.go @@ -19,6 +19,7 @@ var ( tvCache sync.Map ) +// GetTableName returns table name func GetTableName(mapper Mapper, v reflect.Value) string { if v.Type().Implements(tpTableName) { return v.Interface().(TableName).TableName() diff --git a/schemas/column.go b/schemas/column.go index 4f32afab..5808b84d 100644 --- a/schemas/column.go +++ b/schemas/column.go @@ -13,6 +13,7 @@ import ( "time" ) +// enumerates all database mapping way const ( TWOSIDES = iota + 1 ONLYTODB diff --git a/schemas/index.go b/schemas/index.go index 9541250f..8f31af52 100644 --- a/schemas/index.go +++ b/schemas/index.go @@ -28,6 +28,7 @@ func NewIndex(name string, indexType int) *Index { return &Index{true, name, indexType, make([]string, 0)} } +// XName returns the special index name for the table func (index *Index) XName(tableName string) string { if !strings.HasPrefix(index.Name, "UQE_") && !strings.HasPrefix(index.Name, "IDX_") { @@ -43,11 +44,10 @@ func (index *Index) XName(tableName string) string { // AddColumn add columns which will be composite index func (index *Index) AddColumn(cols ...string) { - for _, col := range cols { - index.Cols = append(index.Cols, col) - } + index.Cols = append(index.Cols, cols...) } +// Equal return true if the two Index is equal func (index *Index) Equal(dst *Index) bool { if index.Type != dst.Type { return false diff --git a/schemas/pk.go b/schemas/pk.go index 03916b44..da3c7899 100644 --- a/schemas/pk.go +++ b/schemas/pk.go @@ -11,13 +11,16 @@ import ( "xorm.io/xorm/internal/utils" ) +// PK represents primary key values type PK []interface{} +// NewPK creates primay keys func NewPK(pks ...interface{}) *PK { p := PK(pks) return &p } +// IsZero return true if primay keys are zero func (p *PK) IsZero() bool { for _, k := range *p { if utils.IsZero(k) { @@ -27,6 +30,7 @@ func (p *PK) IsZero() bool { return false } +// ToString convert to SQL string func (p *PK) ToString() (string, error) { buf := new(bytes.Buffer) enc := gob.NewEncoder(buf) @@ -34,6 +38,7 @@ func (p *PK) ToString() (string, error) { return buf.String(), err } +// FromString reads content to load primary keys func (p *PK) FromString(content string) error { dec := gob.NewDecoder(bytes.NewBufferString(content)) err := dec.Decode(p) diff --git a/schemas/quote.go b/schemas/quote.go index a0070048..71040ad9 100644 --- a/schemas/quote.go +++ b/schemas/quote.go @@ -16,10 +16,10 @@ type Quoter struct { } var ( - // AlwaysFalseReverse always think it's not a reverse word + // AlwaysNoReserve always think it's not a reverse word AlwaysNoReserve = func(string) bool { return false } - // AlwaysReverse always reverse the word + // AlwaysReserve always reverse the word AlwaysReserve = func(string) bool { return true } // CommanQuoteMark represnets the common quote mark @@ -29,10 +29,12 @@ var ( CommonQuoter = Quoter{CommanQuoteMark, CommanQuoteMark, AlwaysReserve} ) +// IsEmpty return true if no prefix and suffix func (q Quoter) IsEmpty() bool { return q.Prefix == 0 && q.Suffix == 0 } +// Quote quote a string func (q Quoter) Quote(s string) string { var buf strings.Builder q.QuoteTo(&buf, s) @@ -59,12 +61,14 @@ func (q Quoter) Trim(s string) string { return buf.String() } +// Join joins a slice with quoters func (q Quoter) Join(a []string, sep string) string { var b strings.Builder q.JoinWrite(&b, a, sep) return b.String() } +// JoinWrite writes quoted content to a builder func (q Quoter) JoinWrite(b *strings.Builder, a []string, sep string) error { if len(a) == 0 { return nil diff --git a/schemas/table.go b/schemas/table.go index 7ca9531f..bfa517aa 100644 --- a/schemas/table.go +++ b/schemas/table.go @@ -90,23 +90,28 @@ func (table *Table) PKColumns() []*Column { return columns } +// ColumnType returns a column's type func (table *Table) ColumnType(name string) reflect.Type { t, _ := table.Type.FieldByName(name) return t.Type } +// AutoIncrColumn returns autoincrement column func (table *Table) AutoIncrColumn() *Column { return table.GetColumn(table.AutoIncrement) } +// VersionColumn returns version column's information func (table *Table) VersionColumn() *Column { return table.GetColumn(table.Version) } +// UpdatedColumn returns updated column's information func (table *Table) UpdatedColumn() *Column { return table.GetColumn(table.Updated) } +// DeletedColumn returns deleted column's information func (table *Table) DeletedColumn() *Column { return table.GetColumn(table.Deleted) } diff --git a/schemas/type.go b/schemas/type.go index 6b50d184..fc02f015 100644 --- a/schemas/type.go +++ b/schemas/type.go @@ -11,8 +11,10 @@ import ( "time" ) +// DBType represents a database type type DBType string +// enumerates all database types const ( POSTGRES DBType = "postgres" SQLITE DBType = "sqlite3" @@ -28,6 +30,7 @@ type SQLType struct { DefaultLength2 int } +// enumerates all columns types const ( UNKNOW_TYPE = iota TEXT_TYPE @@ -37,6 +40,7 @@ const ( ARRAY_TYPE ) +// IsType reutrns ture if the column type is the same as the parameter func (s *SQLType) IsType(st int) bool { if t, ok := SqlTypes[s.Name]; ok && t == st { return true @@ -44,34 +48,42 @@ func (s *SQLType) IsType(st int) bool { return false } +// IsText returns true if column is a text type func (s *SQLType) IsText() bool { return s.IsType(TEXT_TYPE) } +// IsBlob returns true if column is a binary type func (s *SQLType) IsBlob() bool { return s.IsType(BLOB_TYPE) } +// IsTime returns true if column is a time type func (s *SQLType) IsTime() bool { return s.IsType(TIME_TYPE) } +// IsNumeric returns true if column is a numeric type func (s *SQLType) IsNumeric() bool { return s.IsType(NUMERIC_TYPE) } +// IsArray returns true if column is an array type func (s *SQLType) IsArray() bool { return s.IsType(ARRAY_TYPE) } +// IsJson returns true if column is an array type func (s *SQLType) IsJson() bool { return s.Name == Json || s.Name == Jsonb } +// IsXML returns true if column is an xml type func (s *SQLType) IsXML() bool { return s.Name == XML } +// enumerates all the database column types var ( Bit = "BIT" UnsignedBit = "UNSIGNED BIT" @@ -210,53 +222,55 @@ var ( // !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison var ( - c_EMPTY_STRING string - c_BOOL_DEFAULT bool - c_BYTE_DEFAULT byte - c_COMPLEX64_DEFAULT complex64 - c_COMPLEX128_DEFAULT complex128 - c_FLOAT32_DEFAULT float32 - c_FLOAT64_DEFAULT float64 - c_INT64_DEFAULT int64 - c_UINT64_DEFAULT uint64 - c_INT32_DEFAULT int32 - c_UINT32_DEFAULT uint32 - c_INT16_DEFAULT int16 - c_UINT16_DEFAULT uint16 - c_INT8_DEFAULT int8 - c_UINT8_DEFAULT uint8 - c_INT_DEFAULT int - c_UINT_DEFAULT uint - c_TIME_DEFAULT time.Time + emptyString string + boolDefault bool + byteDefault byte + complex64Default complex64 + complex128Default complex128 + float32Default float32 + float64Default float64 + int64Default int64 + uint64Default uint64 + int32Default int32 + uint32Default uint32 + int16Default int16 + uint16Default uint16 + int8Default int8 + uint8Default uint8 + intDefault int + uintDefault uint + timeDefault time.Time ) +// enumerates all types var ( - IntType = reflect.TypeOf(c_INT_DEFAULT) - Int8Type = reflect.TypeOf(c_INT8_DEFAULT) - Int16Type = reflect.TypeOf(c_INT16_DEFAULT) - Int32Type = reflect.TypeOf(c_INT32_DEFAULT) - Int64Type = reflect.TypeOf(c_INT64_DEFAULT) + IntType = reflect.TypeOf(intDefault) + Int8Type = reflect.TypeOf(int8Default) + Int16Type = reflect.TypeOf(int16Default) + Int32Type = reflect.TypeOf(int32Default) + Int64Type = reflect.TypeOf(int64Default) - UintType = reflect.TypeOf(c_UINT_DEFAULT) - Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT) - Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT) - Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT) - Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT) + UintType = reflect.TypeOf(uintDefault) + Uint8Type = reflect.TypeOf(uint8Default) + Uint16Type = reflect.TypeOf(uint16Default) + Uint32Type = reflect.TypeOf(uint32Default) + Uint64Type = reflect.TypeOf(uint64Default) - Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT) - Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT) + Float32Type = reflect.TypeOf(float32Default) + Float64Type = reflect.TypeOf(float64Default) - Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT) - Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT) + Complex64Type = reflect.TypeOf(complex64Default) + Complex128Type = reflect.TypeOf(complex128Default) - StringType = reflect.TypeOf(c_EMPTY_STRING) - BoolType = reflect.TypeOf(c_BOOL_DEFAULT) - ByteType = reflect.TypeOf(c_BYTE_DEFAULT) + StringType = reflect.TypeOf(emptyString) + BoolType = reflect.TypeOf(boolDefault) + ByteType = reflect.TypeOf(byteDefault) BytesType = reflect.SliceOf(ByteType) - TimeType = reflect.TypeOf(c_TIME_DEFAULT) + TimeType = reflect.TypeOf(timeDefault) ) +// enumerates all types var ( PtrIntType = reflect.PtrTo(IntType) PtrInt8Type = reflect.PtrTo(Int8Type) @@ -301,7 +315,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.Complex64, reflect.Complex128: st = SQLType{Varchar, 64, 0} case reflect.Array, reflect.Slice, reflect.Map: - if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) { + if t.Elem() == reflect.TypeOf(byteDefault) { st = SQLType{Blob, 0, 0} } else { st = SQLType{Text, 0, 0} @@ -325,7 +339,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { return } -// default sql type change to go types +// SQLType2Type convert default sql type change to go types func SQLType2Type(st SQLType) reflect.Type { name := strings.ToUpper(st.Name) switch name { @@ -344,7 +358,7 @@ func SQLType2Type(st SQLType) reflect.Type { case Bool: return reflect.TypeOf(true) case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: - return reflect.TypeOf(c_TIME_DEFAULT) + return reflect.TypeOf(timeDefault) case Decimal, Numeric, Money, SmallMoney: return reflect.TypeOf("") default: diff --git a/tags/parser.go b/tags/parser.go index 45dd6d9d..5ad67b53 100644 --- a/tags/parser.go +++ b/tags/parser.go @@ -21,9 +21,11 @@ import ( ) var ( + // ErrUnsupportedType represents an unsupported type error ErrUnsupportedType = errors.New("Unsupported type") ) +// Parser represents a parser for xorm tag type Parser struct { identifier string dialect dialects.Dialect @@ -34,6 +36,7 @@ type Parser struct { tableCache sync.Map // map[reflect.Type]*schemas.Table } +// NewParser creates a tag parser func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnMapper names.Mapper, cacherMgr *caches.Manager) *Parser { return &Parser{ identifier: identifier, @@ -45,29 +48,35 @@ func NewParser(identifier string, dialect dialects.Dialect, tableMapper, columnM } } +// GetTableMapper returns table mapper func (parser *Parser) GetTableMapper() names.Mapper { return parser.tableMapper } +// SetTableMapper sets table mapper func (parser *Parser) SetTableMapper(mapper names.Mapper) { parser.ClearCaches() parser.tableMapper = mapper } +// GetColumnMapper returns column mapper func (parser *Parser) GetColumnMapper() names.Mapper { return parser.columnMapper } +// SetColumnMapper sets column mapper func (parser *Parser) SetColumnMapper(mapper names.Mapper) { parser.ClearCaches() parser.columnMapper = mapper } +// SetIdentifier sets tag identifier func (parser *Parser) SetIdentifier(identifier string) { parser.ClearCaches() parser.identifier = identifier } +// ParseWithCache parse a struct with cache func (parser *Parser) ParseWithCache(v reflect.Value) (*schemas.Table, error) { t := v.Type() tableI, ok := parser.tableCache.Load(t)