diff --git a/core/.circleci/config.yml b/core/.circleci/config.yml new file mode 100644 index 00000000..ce15f60e --- /dev/null +++ b/core/.circleci/config.yml @@ -0,0 +1,42 @@ +# Golang CircleCI 2.0 configuration file +# +# Check https://circleci.com/docs/2.0/language-go/ for more details +version: 2 +jobs: + build: + docker: + # specify the version + - image: circleci/golang:1.10 + + # CircleCI PostgreSQL images available at: https://hub.docker.com/r/circleci/postgres/ + - image: circleci/mysql:5.7 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: true + MYSQL_DATABASE: core_test + MYSQL_HOST: 127.0.0.1 + MYSQL_ROOT_HOST: '%' + MYSQL_USER: root + + # Specify service dependencies here if necessary + # CircleCI maintains a library of pre-built images + # documented at https://circleci.com/docs/2.0/circleci-images/ + # - image: circleci/postgres:9.4 + + #### TEMPLATE_NOTE: go expects specific checkout path representing url + #### expecting it in the form of + #### /go/src/github.com/circleci/go-tool + #### /go/src/bitbucket.org/circleci/go-tool + working_directory: /go/src/github.com/go-xorm/core + steps: + - checkout + + # specify any bash command here prefixed with `run: ` + #- run: mysql -u root -e "CREATE DATABASE core_test DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci" + - run: go get -u github.com/golang/lint/golint + - run: go get -u github.com/wadey/gocovmerge + - run: golint ./... + - run: go get -t -d -v ./... + - run: go test -v -race -coverprofile=coverage-1.txt -covermode=atomic + - run: go test -v -race --dbtype=sqlite3 -coverprofile=coverage-2.txt -covermode=atomic + - run: gocovmerge coverage-1.txt coverage-2.txt > coverage.txt + - run: bash <(curl -s https://codecov.io/bash) \ No newline at end of file diff --git a/core/.drone.yml b/core/.drone.yml new file mode 100644 index 00000000..3c118d4c --- /dev/null +++ b/core/.drone.yml @@ -0,0 +1,33 @@ +--- +kind: pipeline +name: go1.12 + +steps: + +- name: test + pull: default + image: golang:1.12 + commands: + - go vet + - "go test -v -race -coverprofile=coverage.txt -covermode=atomic -dbConn=\"root:@tcp(mysql:3306)/core_test?charset=utf8mb4\"" + environment: + GO111MODULE: "on" + GOPROXY: https://goproxy.cn + when: + event: + - push + - tag + - pull_request + +services: +- name: mysql + pull: default + image: mysql:5.7 + environment: + MYSQL_ALLOW_EMPTY_PASSWORD: yes + MYSQL_DATABASE: core_test + when: + event: + - push + - tag + - pull_request \ No newline at end of file diff --git a/core/.gitignore b/core/.gitignore new file mode 100644 index 00000000..98e6ef67 --- /dev/null +++ b/core/.gitignore @@ -0,0 +1 @@ +*.db diff --git a/core/LICENSE b/core/LICENSE new file mode 100644 index 00000000..11307978 --- /dev/null +++ b/core/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013 - 2015 Lunny Xiao +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/core/README.md b/core/README.md new file mode 100644 index 00000000..54436b68 --- /dev/null +++ b/core/README.md @@ -0,0 +1,118 @@ +Core is a lightweight wrapper of sql.DB. + +[![Build Status](https://drone.gitea.com/api/badges/xorm/core/status.svg)](https://drone.gitea.com/xorm/core) +[![Test Coverage](https://gocover.io/_badge/xorm.io/core)](https://gocover.io/xorm.io/core) +[![Go Report Card](https://goreportcard.com/badge/code.gitea.io/gitea)](https://goreportcard.com/report/xorm.io/core) + +# Open +```Go +db, _ := core.Open(db, connstr) +``` + +# SetMapper +```Go +db.SetMapper(SameMapper()) +``` + +## Scan usage + +### Scan +```Go +rows, _ := db.Query() +for rows.Next() { + rows.Scan() +} +``` + +### ScanMap +```Go +rows, _ := db.Query() +for rows.Next() { + rows.ScanMap() +``` + +### ScanSlice + +You can use `[]string`, `[][]byte`, `[]interface{}`, `[]*string`, `[]sql.NullString` to ScanSclice. Notice, slice's length should be equal or less than select columns. + +```Go +rows, _ := db.Query() +cols, _ := rows.Columns() +for rows.Next() { + var s = make([]string, len(cols)) + rows.ScanSlice(&s) +} +``` + +```Go +rows, _ := db.Query() +cols, _ := rows.Columns() +for rows.Next() { + var s = make([]*string, len(cols)) + rows.ScanSlice(&s) +} +``` + +### ScanStruct +```Go +rows, _ := db.Query() +for rows.Next() { + rows.ScanStructByName() + rows.ScanStructByIndex() +} +``` + +## Query usage +```Go +rows, err := db.Query("select * from table where name = ?", name) + +user = User{ + Name:"lunny", +} +rows, err := db.QueryStruct("select * from table where name = ?Name", + &user) + +var user = map[string]interface{}{ + "name": "lunny", +} +rows, err = db.QueryMap("select * from table where name = ?name", + &user) +``` + +## QueryRow usage +```Go +row := db.QueryRow("select * from table where name = ?", name) + +user = User{ + Name:"lunny", +} +row := db.QueryRowStruct("select * from table where name = ?Name", + &user) + +var user = map[string]interface{}{ + "name": "lunny", +} +row = db.QueryRowMap("select * from table where name = ?name", + &user) +``` + +## Exec usage +```Go +db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", name, title, age, alias...) + +user = User{ + Name:"lunny", + Title:"test", + Age: 18, +} +result, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + +var user = map[string]interface{}{ + "Name": "lunny", + "Title": "test", + "Age": 18, +} +result, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) +``` \ No newline at end of file diff --git a/core/benchmark.sh b/core/benchmark.sh new file mode 100755 index 00000000..eab9e57e --- /dev/null +++ b/core/benchmark.sh @@ -0,0 +1 @@ +go test -v -bench=. -run=XXX diff --git a/core/cache.go b/core/cache.go new file mode 100644 index 00000000..982abe6a --- /dev/null +++ b/core/cache.go @@ -0,0 +1,95 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "strings" + "time" +) + +const ( + // CacheExpired is default cache expired time + CacheExpired = 60 * time.Minute + // CacheMaxMemory is not use now + CacheMaxMemory = 256 + // CacheGcInterval represents interval time to clear all expired nodes + CacheGcInterval = 10 * time.Minute + // CacheGcMaxRemoved represents max nodes removed when gc + CacheGcMaxRemoved = 20 +) + +// list all the errors +var ( + ErrCacheMiss = errors.New("xorm/cache: key not found") + ErrNotStored = errors.New("xorm/cache: not stored") +) + +// CacheStore is a interface to store cache +type CacheStore interface { + // key is primary key or composite primary key + // value is struct's pointer + // key format : -p--... + Put(key string, value interface{}) error + Get(key string) (interface{}, error) + Del(key string) error +} + +// Cacher is an interface to provide cache +// id format : u--... +type Cacher interface { + GetIds(tableName, sql string) interface{} + GetBean(tableName string, id string) interface{} + PutIds(tableName, sql string, ids interface{}) + PutBean(tableName string, id string, obj interface{}) + DelIds(tableName, sql string) + DelBean(tableName string, id string) + ClearIds(tableName string) + ClearBeans(tableName string) +} + +func encodeIds(ids []PK) (string, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(ids) + + return buf.String(), err +} + +func decodeIds(s string) ([]PK, error) { + pks := make([]PK, 0) + + dec := gob.NewDecoder(strings.NewReader(s)) + err := dec.Decode(&pks) + + return pks, err +} + +// GetCacheSql returns cacher PKs via SQL +func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]PK, error) { + bytes := m.GetIds(tableName, GenSqlKey(sql, args)) + if bytes == nil { + return nil, errors.New("Not Exist") + } + return decodeIds(bytes.(string)) +} + +// PutCacheSql puts cacher SQL and PKs +func PutCacheSql(m Cacher, ids []PK, tableName, sql string, args interface{}) error { + bytes, err := encodeIds(ids) + if err != nil { + return err + } + m.PutIds(tableName, GenSqlKey(sql, args), bytes) + return nil +} + +// GenSqlKey generates cache key +func GenSqlKey(sql string, args interface{}) string { + return fmt.Sprintf("%v-%v", sql, args) +} diff --git a/core/column.go b/core/column.go new file mode 100644 index 00000000..8f375db5 --- /dev/null +++ b/core/column.go @@ -0,0 +1,166 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +const ( + TWOSIDES = iota + 1 + ONLYTODB + ONLYFROMDB +) + +// Column defines database column +type Column struct { + Name string + TableName string + FieldName string + SQLType SQLType + IsJSON bool + Length int + Length2 int + Nullable bool + Default string + Indexes map[string]int + IsPrimaryKey bool + IsAutoIncrement bool + MapType int + IsCreated bool + IsUpdated bool + IsDeleted bool + IsCascade bool + IsVersion bool + DefaultIsEmpty bool // false means column has no default set, but not default value is empty + EnumOptions map[string]int + SetOptions map[string]int + DisableTimeZone bool + TimeZone *time.Location // column specified time zone + Comment string +} + +// NewColumn creates a new column +func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column { + return &Column{ + Name: name, + TableName: "", + FieldName: fieldName, + SQLType: sqlType, + Length: len1, + Length2: len2, + Nullable: nullable, + Default: "", + Indexes: make(map[string]int), + IsPrimaryKey: false, + IsAutoIncrement: false, + MapType: TWOSIDES, + IsCreated: false, + IsUpdated: false, + IsDeleted: false, + IsCascade: false, + IsVersion: false, + DefaultIsEmpty: true, // default should be no default + EnumOptions: make(map[string]int), + Comment: "", + } +} + +// String generate column description string according dialect +func (col *Column) String(d Dialect) string { + sql := d.Quote(col.Name) + " " + + sql += d.SqlType(col) + " " + + if col.IsPrimaryKey { + sql += "PRIMARY KEY " + if col.IsAutoIncrement { + sql += d.AutoIncrStr() + " " + } + } + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + if d.ShowCreateNull() { + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + } + + return sql +} + +// StringNoPk generate column description string according dialect without primary keys +func (col *Column) StringNoPk(d Dialect) string { + sql := d.Quote(col.Name) + " " + + sql += d.SqlType(col) + " " + + if col.Default != "" { + sql += "DEFAULT " + col.Default + " " + } + + if d.ShowCreateNull() { + if col.Nullable { + sql += "NULL " + } else { + sql += "NOT NULL " + } + } + + return sql +} + +// ValueOf returns column's filed of struct's value +func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) { + dataStruct := reflect.Indirect(reflect.ValueOf(bean)) + return col.ValueOfV(&dataStruct) +} + +// ValueOfV returns column's filed of struct's value accept reflevt value +func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { + var fieldValue reflect.Value + fieldPath := strings.Split(col.FieldName, ".") + + if dataStruct.Type().Kind() == reflect.Map { + keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1]) + fieldValue = dataStruct.MapIndex(keyValue) + return &fieldValue, nil + } else if dataStruct.Type().Kind() == reflect.Interface { + structValue := reflect.ValueOf(dataStruct.Interface()) + dataStruct = &structValue + } + + level := len(fieldPath) + fieldValue = dataStruct.FieldByName(fieldPath[0]) + for i := 0; i < level-1; i++ { + if !fieldValue.IsValid() { + break + } + if fieldValue.Kind() == reflect.Struct { + fieldValue = fieldValue.FieldByName(fieldPath[i+1]) + } else if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1]) + } else { + return nil, fmt.Errorf("field %v is not valid", col.FieldName) + } + } + + if !fieldValue.IsValid() { + return nil, fmt.Errorf("field %v is not valid", col.FieldName) + } + + return &fieldValue, nil +} diff --git a/core/converstion.go b/core/converstion.go new file mode 100644 index 00000000..9703c36e --- /dev/null +++ b/core/converstion.go @@ -0,0 +1,12 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +// Conversion is an interface. A type implements Conversion will according +// the custom method to fill into database and retrieve from database. +type Conversion interface { + FromDB([]byte) error + ToDB() ([]byte, error) +} diff --git a/core/db.go b/core/db.go new file mode 100644 index 00000000..4847937c --- /dev/null +++ b/core/db.go @@ -0,0 +1,227 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "sync" +) + +var ( + // DefaultCacheSize sets the default cache size + DefaultCacheSize = 200 +) + +func MapToSlice(query string, mp interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return "", []interface{}{}, ErrNoMapPointer + } + + args := make([]interface{}, 0, len(vv.Elem().MapKeys())) + var err error + query = re.ReplaceAllStringFunc(query, func(src string) string { + v := vv.Elem().MapIndex(reflect.ValueOf(src[1:])) + if !v.IsValid() { + err = fmt.Errorf("map key %s is missing", src[1:]) + } else { + args = append(args, v.Interface()) + } + return "?" + }) + + return query, args, err +} + +func StructToSlice(query string, st interface{}) (string, []interface{}, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return "", []interface{}{}, ErrNoStructPointer + } + + args := make([]interface{}, 0) + var err error + query = re.ReplaceAllStringFunc(query, func(src string) string { + fv := vv.Elem().FieldByName(src[1:]).Interface() + if v, ok := fv.(driver.Valuer); ok { + var value driver.Value + value, err = v.Value() + if err != nil { + return "?" + } + args = append(args, value) + } else { + args = append(args, fv) + } + return "?" + }) + if err != nil { + return "", []interface{}{}, err + } + return query, args, nil +} + +type cacheStruct struct { + value reflect.Value + idx int +} + +// DB is a wrap of sql.DB with extra contents +type DB struct { + *sql.DB + Mapper IMapper + reflectCache map[reflect.Type]*cacheStruct + reflectCacheMutex sync.RWMutex +} + +// Open opens a database +func Open(driverName, dataSourceName string) (*DB, error) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + return nil, err + } + return &DB{ + DB: db, + Mapper: NewCacheMapper(&SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + }, nil +} + +// FromDB creates a DB from a sql.DB +func FromDB(db *sql.DB) *DB { + return &DB{ + DB: db, + Mapper: NewCacheMapper(&SnakeMapper{}), + reflectCache: make(map[reflect.Type]*cacheStruct), + } +} + +func (db *DB) reflectNew(typ reflect.Type) reflect.Value { + db.reflectCacheMutex.Lock() + defer db.reflectCacheMutex.Unlock() + cs, ok := db.reflectCache[typ] + if !ok || cs.idx+1 > DefaultCacheSize-1 { + cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), DefaultCacheSize, DefaultCacheSize), 0} + db.reflectCache[typ] = cs + } else { + cs.idx = cs.idx + 1 + } + return cs.value.Index(cs.idx).Addr() +} + +// QueryContext overwrites sql.DB.QueryContext +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + rows, err := db.DB.QueryContext(ctx, query, args...) + if err != nil { + if rows != nil { + rows.Close() + } + return nil, err + } + return &Rows{rows, db}, nil +} + +// Query overwrites sql.DB.Query +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +// QueryMapContext executes query with parameters via map and context +func (db *DB) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return db.QueryContext(ctx, query, args...) +} + +// QueryMap executes query with parameters via map +func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { + return db.QueryMapContext(context.Background(), query, mp) +} + +func (db *DB) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return db.QueryContext(ctx, query, args...) +} + +func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { + return db.QueryStructContext(context.Background(), query, st) +} + +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return &Row{nil, err} + } + return &Row{rows, nil} +} + +func (db *DB) QueryRow(query string, args ...interface{}) *Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +func (db *DB) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err} + } + return db.QueryRowContext(ctx, query, args...) +} + +func (db *DB) QueryRowMap(query string, mp interface{}) *Row { + return db.QueryRowMapContext(context.Background(), query, mp) +} + +func (db *DB) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err} + } + return db.QueryRowContext(ctx, query, args...) +} + +func (db *DB) QueryRowStruct(query string, st interface{}) *Row { + return db.QueryRowStructContext(context.Background(), query, st) +} + +var ( + re = regexp.MustCompile(`[?](\w+)`) +) + +// ExecMapContext exec map with context.Context +// insert into (name) values (?) +// insert into (name) values (?name) +func (db *DB) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return db.DB.ExecContext(ctx, query, args...) +} + +func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) { + return db.ExecMapContext(context.Background(), query, mp) +} + +func (db *DB) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return db.DB.ExecContext(ctx, query, args...) +} + +func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) { + return db.ExecStructContext(context.Background(), query, st) +} diff --git a/core/db_test.go b/core/db_test.go new file mode 100644 index 00000000..affc1f91 --- /dev/null +++ b/core/db_test.go @@ -0,0 +1,683 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "errors" + "flag" + "os" + "testing" + "time" + + _ "github.com/go-sql-driver/mysql" + _ "github.com/mattn/go-sqlite3" +) + +var ( + dbtype = flag.String("dbtype", "mysql", "database type") + dbConn = flag.String("dbConn", "root:@/core_test?charset=utf8", "database connect string") + createTableSql string +) + +func TestMain(m *testing.M) { + flag.Parse() + + switch *dbtype { + case "sqlite3": + createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" + case "mysql": + fallthrough + default: + createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " + + "`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);" + } + + exitCode := m.Run() + + os.Exit(exitCode) +} + +func testOpen() (*DB, error) { + switch *dbtype { + case "sqlite3": + os.Remove("./test.db") + return Open("sqlite3", "./test.db") + case "mysql": + return Open("mysql", *dbConn) + default: + panic("no db type") + } +} + +func BenchmarkOriQuery(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var Id int64 + var Name, Title, Alias, NickName string + var Age float32 + var Created NullTime + err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) + if err != nil { + b.Error(err) + } + //fmt.Println(Id, Name, Title, Age, Alias, NickName) + } + rows.Close() + } +} + +type User struct { + Id int64 + Name string + Title string + Age float32 + Alias string + NickName string + Created NullTime +} + +func BenchmarkStructQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByIndex(&user) + if err != nil { + b.Error(err) + } + if user.Name != "xlw" { + b.Log(user) + b.Error(errors.New("name should be xlw")) + } + } + rows.Close() + } +} + +func BenchmarkStruct2Query(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + db.Mapper = NewCacheMapper(&SnakeMapper{}) + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + b.Error(err) + } + if user.Name != "xlw" { + b.Log(user) + b.Error(errors.New("name should be xlw")) + } + } + rows.Close() + } +} + +func BenchmarkSliceInterfaceQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]interface{}, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + b.Log(slice) + switch slice[1].(type) { + case *string: + if *slice[1].(*string) != "xlw" { + b.Error(errors.New("name should be xlw")) + } + case []byte: + if string(slice[1].([]byte)) != "xlw" { + b.Error(errors.New("name should be xlw")) + } + } + } + + rows.Close() + } +} + +/*func BenchmarkSliceBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([][]byte, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if string(slice[1]) != "xlw" { + fmt.Println(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} +*/ + +func BenchmarkSliceStringQuery(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + cols, err := rows.Columns() + if err != nil { + b.Error(err) + } + + for rows.Next() { + slice := make([]*string, len(cols)) + err = rows.ScanSlice(&slice) + if err != nil { + b.Error(err) + } + if (*slice[1]) != "xlw" { + b.Log(slice) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} + +func BenchmarkMapInterfaceQuery(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string]interface{}) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + switch m["name"].(type) { + case string: + if m["name"].(string) != "xlw" { + b.Log(m) + b.Error(errors.New("name should be xlw")) + } + case []byte: + if string(m["name"].([]byte)) != "xlw" { + b.Log(m) + b.Error(errors.New("name should be xlw")) + } + } + } + + rows.Close() + } +} + +/*func BenchmarkMapBytesQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string][]byte) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + if string(m["name"]) != "xlw" { + fmt.Println(m) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +} +*/ +/* +func BenchmarkMapStringQuery(b *testing.B) { + b.StopTimer() + os.Remove("./test.db") + db, err := Open("sqlite3", "./test.db") + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + for i := 0; i < 50; i++ { + _, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("select * from user") + if err != nil { + b.Error(err) + } + + for rows.Next() { + m := make(map[string]string) + err = rows.ScanMap(&m) + if err != nil { + b.Error(err) + } + if m["name"] != "xlw" { + fmt.Println(m) + b.Error(errors.New("name should be xlw")) + } + } + + rows.Close() + } +}*/ + +func BenchmarkExec(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + for i := 0; i < b.N; i++ { + _, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", + "xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now()) + if err != nil { + b.Error(err) + } + } +} + +func BenchmarkExecMap(b *testing.B) { + b.StopTimer() + + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + "created": time.Now(), + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name, created) "+ + "values (?name,?title,?age,?alias,?nick_name,?created)", + &mp) + if err != nil { + b.Error(err) + } + } +} + +func TestExecMap(t *testing.T) { + db, err := testOpen() + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + t.Error(err) + } + + mp := map[string]interface{}{ + "name": "xlw", + "title": "tester", + "age": 1.2, + "alias": "lunny", + "nick_name": "lunny xiao", + "created": time.Now(), + } + + _, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?name,?title,?age,?alias,?nick_name,?created)", + &mp) + if err != nil { + t.Error(err) + } + + rows, err := db.Query("select * from user") + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + t.Error(err) + } + t.Log("--", user) + } +} + +func TestExecStruct(t *testing.T) { + db, err := testOpen() + if err != nil { + t.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + t.Error(err) + } + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + Created: NullTime(time.Now()), + } + + _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + if err != nil { + t.Error(err) + } + + rows, err := db.QueryStruct("select * from user where `name` = ?Name", &user) + if err != nil { + t.Error(err) + } + + for rows.Next() { + var user User + err = rows.ScanStructByName(&user) + if err != nil { + t.Error(err) + } + t.Log("1--", user) + } +} + +func BenchmarkExecStruct(b *testing.B) { + b.StopTimer() + db, err := testOpen() + if err != nil { + b.Error(err) + } + defer db.Close() + + _, err = db.Exec(createTableSql) + if err != nil { + b.Error(err) + } + + b.StartTimer() + + user := User{Name: "xlw", + Title: "tester", + Age: 1.2, + Alias: "lunny", + NickName: "lunny xiao", + Created: NullTime(time.Now()), + } + + for i := 0; i < b.N; i++ { + _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ + "values (?Name,?Title,?Age,?Alias,?NickName,?Created)", + &user) + if err != nil { + b.Error(err) + } + } +} diff --git a/core/dialect.go b/core/dialect.go new file mode 100644 index 00000000..c166596c --- /dev/null +++ b/core/dialect.go @@ -0,0 +1,327 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "fmt" + "strings" + "time" +) + +type DbType string + +type Uri struct { + DbType DbType + Proto string + Host string + Port string + DbName string + User string + Passwd string + Charset string + Laddr string + Raddr string + Timeout time.Duration + Schema string +} + +// a dialect is a driver's wrapper +type Dialect interface { + SetLogger(logger ILogger) + Init(*DB, *Uri, string, string) error + URI() *Uri + DB() *DB + DBType() DbType + SqlType(*Column) string + FormatBytes(b []byte) string + + DriverName() string + DataSourceName() string + + IsReserved(string) bool + Quote(string) string + + AndStr() string + OrStr() string + EqStr() string + RollBackStr() string + AutoIncrStr() string + + SupportInsertMany() bool + SupportEngine() bool + SupportCharset() bool + SupportDropIfExists() bool + IndexOnTable() bool + ShowCreateNull() bool + + IndexCheckSql(tableName, idxName string) (string, []interface{}) + TableCheckSql(tableName string) (string, []interface{}) + + IsColumnExist(tableName string, colName string) (bool, error) + + CreateTableSql(table *Table, tableName, storeEngine, charset string) string + DropTableSql(tableName string) string + CreateIndexSql(tableName string, index *Index) string + DropIndexSql(tableName string, index *Index) string + + ModifyColumnSql(tableName string, col *Column) string + + ForUpdateSql(query string) string + + // CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error + // MustDropTable(tableName string) error + + GetColumns(tableName string) ([]string, map[string]*Column, error) + GetTables() ([]*Table, error) + GetIndexes(tableName string) (map[string]*Index, error) + + Filters() []Filter + SetParams(params map[string]string) +} + +func OpenDialect(dialect Dialect) (*DB, error) { + return Open(dialect.DriverName(), dialect.DataSourceName()) +} + +// Base represents a basic dialect and all real dialects could embed this struct +type Base struct { + db *DB + dialect Dialect + driverName string + dataSourceName string + logger ILogger + *Uri +} + +func (b *Base) DB() *DB { + return b.db +} + +func (b *Base) SetLogger(logger ILogger) { + b.logger = logger +} + +func (b *Base) Init(db *DB, dialect Dialect, uri *Uri, drivername, dataSourceName string) error { + b.db, b.dialect, b.Uri = db, dialect, uri + b.driverName, b.dataSourceName = drivername, dataSourceName + return nil +} + +func (b *Base) URI() *Uri { + return b.Uri +} + +func (b *Base) DBType() DbType { + return b.Uri.DbType +} + +func (b *Base) FormatBytes(bs []byte) string { + return fmt.Sprintf("0x%x", bs) +} + +func (b *Base) DriverName() string { + return b.driverName +} + +func (b *Base) ShowCreateNull() bool { + return true +} + +func (b *Base) DataSourceName() string { + return b.dataSourceName +} + +func (b *Base) AndStr() string { + return "AND" +} + +func (b *Base) OrStr() string { + return "OR" +} + +func (b *Base) EqStr() string { + return "=" +} + +func (db *Base) RollBackStr() string { + return "ROLL BACK" +} + +func (db *Base) SupportDropIfExists() bool { + return true +} + +func (db *Base) DropTableSql(tableName string) string { + quote := db.dialect.Quote + return fmt.Sprintf("DROP TABLE IF EXISTS %s", quote(tableName)) +} + +func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { + db.LogSQL(query, args) + rows, err := db.DB().Query(query, args...) + if err != nil { + return false, err + } + defer rows.Close() + + if rows.Next() { + return true, nil + } + return false, nil +} + +func (db *Base) IsColumnExist(tableName, colName string) (bool, error) { + query := fmt.Sprintf( + "SELECT %v FROM %v.%v WHERE %v = ? AND %v = ? AND %v = ?", + db.dialect.Quote("COLUMN_NAME"), + db.dialect.Quote("INFORMATION_SCHEMA"), + db.dialect.Quote("COLUMNS"), + db.dialect.Quote("TABLE_SCHEMA"), + db.dialect.Quote("TABLE_NAME"), + db.dialect.Quote("COLUMN_NAME"), + ) + return db.HasRecords(query, db.DbName, tableName, colName) +} + +/* +func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error { + sql, args := db.dialect.TableCheckSql(tableName) + rows, err := db.DB().Query(sql, args...) + if db.Logger != nil { + db.Logger.Info("[sql]", sql, args) + } + if err != nil { + return err + } + defer rows.Close() + + if rows.Next() { + return nil + } + + sql = db.dialect.CreateTableSql(table, tableName, storeEngine, charset) + _, err = db.DB().Exec(sql) + if db.Logger != nil { + db.Logger.Info("[sql]", sql) + } + return err +}*/ + +func (db *Base) CreateIndexSql(tableName string, index *Index) string { + quote := db.dialect.Quote + var unique string + var idxName string + if index.Type == UniqueType { + unique = " UNIQUE" + } + idxName = index.XName(tableName) + return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique, + quote(idxName), quote(tableName), + quote(strings.Join(index.Cols, quote(",")))) +} + +func (db *Base) DropIndexSql(tableName string, index *Index) string { + quote := db.dialect.Quote + var name string + if index.IsRegular { + name = index.XName(tableName) + } else { + name = index.Name + } + return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName)) +} + +func (db *Base) ModifyColumnSql(tableName string, col *Column) string { + return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, col.StringNoPk(db.dialect)) +} + +func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset string) string { + var sql string + sql = "CREATE TABLE IF NOT EXISTS " + if tableName == "" { + tableName = table.Name + } + + sql += b.dialect.Quote(tableName) + sql += " (" + + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys + + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.IsPrimaryKey && len(pkList) == 1 { + sql += col.String(b.dialect) + } else { + sql += col.StringNoPk(b.dialect) + } + sql = strings.TrimSpace(sql) + if b.DriverName() == MYSQL && len(col.Comment) > 0 { + sql += " COMMENT '" + col.Comment + "'" + } + sql += ", " + } + + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(","))) + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" + + if b.dialect.SupportEngine() && storeEngine != "" { + sql += " ENGINE=" + storeEngine + } + if b.dialect.SupportCharset() { + if len(charset) == 0 { + charset = b.dialect.URI().Charset + } + if len(charset) > 0 { + sql += " DEFAULT CHARSET " + charset + } + } + + return sql +} + +func (b *Base) ForUpdateSql(query string) string { + return query + " FOR UPDATE" +} + +func (b *Base) LogSQL(sql string, args []interface{}) { + if b.logger != nil && b.logger.IsShowSQL() { + if len(args) > 0 { + b.logger.Infof("[SQL] %v %v", sql, args) + } else { + b.logger.Infof("[SQL] %v", sql) + } + } +} + +func (b *Base) SetParams(params map[string]string) { +} + +var ( + dialects = map[string]func() Dialect{} +) + +// RegisterDialect register database dialect +func RegisterDialect(dbName DbType, dialectFunc func() Dialect) { + if dialectFunc == nil { + panic("core: Register dialect is nil") + } + dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect +} + +// QueryDialect query if registered database dialect +func QueryDialect(dbName DbType) Dialect { + if d, ok := dialects[strings.ToLower(string(dbName))]; ok { + return d() + } + return nil +} diff --git a/core/driver.go b/core/driver.go new file mode 100644 index 00000000..ceef4ba6 --- /dev/null +++ b/core/driver.go @@ -0,0 +1,31 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +type Driver interface { + Parse(string, string) (*Uri, error) +} + +var ( + drivers = map[string]Driver{} +) + +func RegisterDriver(driverName string, driver Driver) { + if driver == nil { + panic("core: Register driver is nil") + } + if _, dup := drivers[driverName]; dup { + panic("core: Register called twice for driver " + driverName) + } + drivers[driverName] = driver +} + +func QueryDriver(driverName string) Driver { + return drivers[driverName] +} + +func RegisteredDriverSize() int { + return len(drivers) +} diff --git a/core/error.go b/core/error.go new file mode 100644 index 00000000..1fd18348 --- /dev/null +++ b/core/error.go @@ -0,0 +1,14 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import "errors" + +var ( + // ErrNoMapPointer represents error when no map pointer + ErrNoMapPointer = errors.New("mp should be a map's pointer") + // ErrNoStructPointer represents error when no struct pointer + ErrNoStructPointer = errors.New("mp should be a struct's pointer") +) diff --git a/core/filter.go b/core/filter.go new file mode 100644 index 00000000..55be9562 --- /dev/null +++ b/core/filter.go @@ -0,0 +1,93 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "fmt" + "strings" +) + +// Filter is an interface to filter SQL +type Filter interface { + Do(sql string, dialect Dialect, table *Table) string +} + +// QuoteFilter filter SQL replace ` to database's own quote character +type QuoteFilter struct { +} + +func (s *QuoteFilter) Do(sql string, dialect Dialect, table *Table) string { + dummy := dialect.Quote("") + if len(dummy) != 2 { + return sql + } + prefix, suffix := dummy[0], dummy[1] + raw := []byte(sql) + for i, cnt := 0, 0; i < len(raw); i = i + 1 { + if raw[i] == '`' { + if cnt%2 == 0 { + raw[i] = prefix + } else { + raw[i] = suffix + } + cnt++ + } + } + return string(raw) +} + +// IdFilter filter SQL replace (id) to primary key column name +type IdFilter struct { +} + +type Quoter struct { + dialect Dialect +} + +func NewQuoter(dialect Dialect) *Quoter { + return &Quoter{dialect} +} + +func (q *Quoter) Quote(content string) string { + return q.dialect.Quote(content) +} + +func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string { + quoter := NewQuoter(dialect) + if table != nil && len(table.PrimaryKeys) == 1 { + sql = strings.Replace(sql, " `(id)` ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) + sql = strings.Replace(sql, " "+quoter.Quote("(id)")+" ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) + return strings.Replace(sql, " (id) ", " "+quoter.Quote(table.PrimaryKeys[0])+" ", -1) + } + return sql +} + +// SeqFilter filter SQL replace ?, ? ... to $1, $2 ... +type SeqFilter struct { + Prefix string + Start int +} + +func convertQuestionMark(sql, prefix string, start int) string { + var buf strings.Builder + var beginSingleQuote bool + var index = start + for _, c := range sql { + if !beginSingleQuote && c == '?' { + buf.WriteString(fmt.Sprintf("%s%v", prefix, index)) + index++ + } else { + if c == '\'' { + beginSingleQuote = !beginSingleQuote + } + buf.WriteRune(c) + } + } + return buf.String() +} + +func (s *SeqFilter) Do(sql string, dialect Dialect, table *Table) string { + return convertQuestionMark(sql, s.Prefix, s.Start) +} diff --git a/core/filter_test.go b/core/filter_test.go new file mode 100644 index 00000000..a6b7b3bb --- /dev/null +++ b/core/filter_test.go @@ -0,0 +1,39 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type quoterOnly struct { + Dialect +} + +func (q *quoterOnly) Quote(item string) string { + return "[" + item + "]" +} + +func TestQuoteFilter_Do(t *testing.T) { + f := QuoteFilter{} + sql := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" + res := f.Do(sql, new(quoterOnly), nil) + assert.EqualValues(t, + "SELECT [COLUMN_NAME] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_SCHEMA] = ? AND [TABLE_NAME] = ? AND [COLUMN_NAME] = ?", + res, + ) +} + +func TestSeqFilter(t *testing.T) { + var kases = map[string]string{ + "SELECT * FROM TABLE1 WHERE a=? AND b=?": "SELECT * FROM TABLE1 WHERE a=$1 AND b=$2", + "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=? AND b=?": "SELECT 1, '???', '2006-01-02 15:04:05' FROM TABLE1 WHERE a=$1 AND b=$2", + "select '1''?' from issue": "select '1''?' from issue", + "select '1\\??' from issue": "select '1\\??' from issue", + "select '1\\\\',? from issue": "select '1\\\\',$1 from issue", + "select '1\\''?',? from issue": "select '1\\''?',$1 from issue", + } + for sql, result := range kases { + assert.EqualValues(t, result, convertQuestionMark(sql, "$", 1)) + } +} diff --git a/core/go.mod b/core/go.mod new file mode 100644 index 00000000..9dc9c1f2 --- /dev/null +++ b/core/go.mod @@ -0,0 +1,8 @@ +module xorm.io/core + +require ( + github.com/go-sql-driver/mysql v1.4.1 + github.com/mattn/go-sqlite3 v1.10.0 + github.com/stretchr/testify v1.4.0 + google.golang.org/appengine v1.6.0 // indirect +) diff --git a/core/go.sum b/core/go.sum new file mode 100644 index 00000000..172009ed --- /dev/null +++ b/core/go.sum @@ -0,0 +1,20 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= +github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +google.golang.org/appengine v1.6.0 h1:Tfd7cKwKbFRsI8RMAD3oqqw7JPFRrvFlOsfbgVkjOOw= +google.golang.org/appengine v1.6.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/core/ilogger.go b/core/ilogger.go new file mode 100644 index 00000000..0c17750c --- /dev/null +++ b/core/ilogger.go @@ -0,0 +1,37 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +// LogLevel defines a log level +type LogLevel int + +// enumerate all LogLevels +const ( + // !nashtsai! following level also match syslog.Priority value + LOG_DEBUG LogLevel = iota + LOG_INFO + LOG_WARNING + LOG_ERR + LOG_OFF + LOG_UNKNOWN +) + +// ILogger is a logger interface +type ILogger interface { + Debug(v ...interface{}) + Debugf(format string, v ...interface{}) + Error(v ...interface{}) + Errorf(format string, v ...interface{}) + Info(v ...interface{}) + Infof(format string, v ...interface{}) + Warn(v ...interface{}) + Warnf(format string, v ...interface{}) + + Level() LogLevel + SetLevel(l LogLevel) + + ShowSQL(show ...bool) + IsShowSQL() bool +} diff --git a/core/index.go b/core/index.go new file mode 100644 index 00000000..129b5439 --- /dev/null +++ b/core/index.go @@ -0,0 +1,72 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "fmt" + "strings" +) + +// enumerate all index types +const ( + IndexType = iota + 1 + UniqueType +) + +// Index represents a database index +type Index struct { + IsRegular bool + Name string + Type int + Cols []string +} + +func (index *Index) XName(tableName string) string { + if !strings.HasPrefix(index.Name, "UQE_") && + !strings.HasPrefix(index.Name, "IDX_") { + tableParts := strings.Split(strings.Replace(tableName, `"`, "", -1), ".") + tableName = tableParts[len(tableParts)-1] + if index.Type == UniqueType { + return fmt.Sprintf("UQE_%v_%v", tableName, index.Name) + } + return fmt.Sprintf("IDX_%v_%v", tableName, index.Name) + } + return index.Name +} + +// AddColumn add columns which will be composite index +func (index *Index) AddColumn(cols ...string) { + for _, col := range cols { + index.Cols = append(index.Cols, col) + } +} + +func (index *Index) Equal(dst *Index) bool { + if index.Type != dst.Type { + return false + } + if len(index.Cols) != len(dst.Cols) { + return false + } + + for i := 0; i < len(index.Cols); i++ { + var found bool + for j := 0; j < len(dst.Cols); j++ { + if index.Cols[i] == dst.Cols[j] { + found = true + break + } + } + if !found { + return false + } + } + return true +} + +// NewIndex new an index object +func NewIndex(name string, indexType int) *Index { + return &Index{true, name, indexType, make([]string, 0)} +} diff --git a/core/mapper.go b/core/mapper.go new file mode 100644 index 00000000..4df05cb8 --- /dev/null +++ b/core/mapper.go @@ -0,0 +1,258 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "strings" + "sync" +) + +// IMapper represents a name convertation between struct's fields name and table's column name +type IMapper interface { + Obj2Table(string) string + Table2Obj(string) string +} + +type CacheMapper struct { + oriMapper IMapper + obj2tableCache map[string]string + obj2tableMutex sync.RWMutex + table2objCache map[string]string + table2objMutex sync.RWMutex +} + +func NewCacheMapper(mapper IMapper) *CacheMapper { + return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string), + table2objCache: make(map[string]string), + } +} + +func (m *CacheMapper) Obj2Table(o string) string { + m.obj2tableMutex.RLock() + t, ok := m.obj2tableCache[o] + m.obj2tableMutex.RUnlock() + if ok { + return t + } + + t = m.oriMapper.Obj2Table(o) + m.obj2tableMutex.Lock() + m.obj2tableCache[o] = t + m.obj2tableMutex.Unlock() + return t +} + +func (m *CacheMapper) Table2Obj(t string) string { + m.table2objMutex.RLock() + o, ok := m.table2objCache[t] + m.table2objMutex.RUnlock() + if ok { + return o + } + + o = m.oriMapper.Table2Obj(t) + m.table2objMutex.Lock() + m.table2objCache[t] = o + m.table2objMutex.Unlock() + return o +} + +// SameMapper implements IMapper and provides same name between struct and +// database table +type SameMapper struct { +} + +func (m SameMapper) Obj2Table(o string) string { + return o +} + +func (m SameMapper) Table2Obj(t string) string { + return t +} + +// SnakeMapper implements IMapper and provides name transaltion between +// struct and database table +type SnakeMapper struct { +} + +func snakeCasedName(name string) string { + newstr := make([]rune, 0) + for idx, chr := range name { + if isUpper := 'A' <= chr && chr <= 'Z'; isUpper { + if idx > 0 { + newstr = append(newstr, '_') + } + chr -= ('A' - 'a') + } + newstr = append(newstr, chr) + } + + return string(newstr) +} + +func (mapper SnakeMapper) Obj2Table(name string) string { + return snakeCasedName(name) +} + +func titleCasedName(name string) string { + newstr := make([]rune, 0) + upNextChar := true + + name = strings.ToLower(name) + + for _, chr := range name { + switch { + case upNextChar: + upNextChar = false + if 'a' <= chr && chr <= 'z' { + chr -= ('a' - 'A') + } + case chr == '_': + upNextChar = true + continue + } + + newstr = append(newstr, chr) + } + + return string(newstr) +} + +func (mapper SnakeMapper) Table2Obj(name string) string { + return titleCasedName(name) +} + +// GonicMapper implements IMapper. It will consider initialisms when mapping names. +// E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid +type GonicMapper map[string]bool + +func isASCIIUpper(r rune) bool { + return 'A' <= r && r <= 'Z' +} + +func toASCIIUpper(r rune) rune { + if 'a' <= r && r <= 'z' { + r -= ('a' - 'A') + } + return r +} + +func gonicCasedName(name string) string { + newstr := make([]rune, 0, len(name)+3) + for idx, chr := range name { + if isASCIIUpper(chr) && idx > 0 { + if !isASCIIUpper(newstr[len(newstr)-1]) { + newstr = append(newstr, '_') + } + } + + if !isASCIIUpper(chr) && idx > 1 { + l := len(newstr) + if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) { + newstr = append(newstr, newstr[l-1]) + newstr[l-1] = '_' + } + } + + newstr = append(newstr, chr) + } + return strings.ToLower(string(newstr)) +} + +func (mapper GonicMapper) Obj2Table(name string) string { + return gonicCasedName(name) +} + +func (mapper GonicMapper) Table2Obj(name string) string { + newstr := make([]rune, 0) + + name = strings.ToLower(name) + parts := strings.Split(name, "_") + + for _, p := range parts { + _, isInitialism := mapper[strings.ToUpper(p)] + for i, r := range p { + if i == 0 || isInitialism { + r = toASCIIUpper(r) + } + newstr = append(newstr, r) + } + } + + return string(newstr) +} + +// LintGonicMapper is A GonicMapper that contains a list of common initialisms taken from golang/lint +var LintGonicMapper = GonicMapper{ + "API": true, + "ASCII": true, + "CPU": true, + "CSS": true, + "DNS": true, + "EOF": true, + "GUID": true, + "HTML": true, + "HTTP": true, + "HTTPS": true, + "ID": true, + "IP": true, + "JSON": true, + "LHS": true, + "QPS": true, + "RAM": true, + "RHS": true, + "RPC": true, + "SLA": true, + "SMTP": true, + "SSH": true, + "TLS": true, + "TTL": true, + "UI": true, + "UID": true, + "UUID": true, + "URI": true, + "URL": true, + "UTF8": true, + "VM": true, + "XML": true, + "XSRF": true, + "XSS": true, +} + +// PrefixMapper provides prefix table name support +type PrefixMapper struct { + Mapper IMapper + Prefix string +} + +func (mapper PrefixMapper) Obj2Table(name string) string { + return mapper.Prefix + mapper.Mapper.Obj2Table(name) +} + +func (mapper PrefixMapper) Table2Obj(name string) string { + return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):]) +} + +func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper { + return PrefixMapper{mapper, prefix} +} + +// SuffixMapper provides suffix table name support +type SuffixMapper struct { + Mapper IMapper + Suffix string +} + +func (mapper SuffixMapper) Obj2Table(name string) string { + return mapper.Mapper.Obj2Table(name) + mapper.Suffix +} + +func (mapper SuffixMapper) Table2Obj(name string) string { + return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)]) +} + +func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper { + return SuffixMapper{mapper, suffix} +} diff --git a/core/mapper_test.go b/core/mapper_test.go new file mode 100644 index 00000000..aa21c5c1 --- /dev/null +++ b/core/mapper_test.go @@ -0,0 +1,49 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "testing" +) + +func TestGonicMapperFromObj(t *testing.T) { + testCases := map[string]string{ + "HTTPLib": "http_lib", + "id": "id", + "ID": "id", + "IDa": "i_da", + "iDa": "i_da", + "IDAa": "id_aa", + "aID": "a_id", + "aaID": "aa_id", + "aaaID": "aaa_id", + "MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name", + } + + for in, expected := range testCases { + out := gonicCasedName(in) + if out != expected { + t.Errorf("Given %s, expected %s but got %s", in, expected, out) + } + } +} + +func TestGonicMapperToObj(t *testing.T) { + testCases := map[string]string{ + "http_lib": "HTTPLib", + "id": "ID", + "ida": "Ida", + "id_aa": "IDAa", + "aa_id": "AaID", + "my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName", + } + + for in, expected := range testCases { + out := LintGonicMapper.Table2Obj(in) + if out != expected { + t.Errorf("Given %s, expected %s but got %s", in, expected, out) + } + } +} diff --git a/core/pk.go b/core/pk.go new file mode 100644 index 00000000..05a7672d --- /dev/null +++ b/core/pk.go @@ -0,0 +1,30 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "bytes" + "encoding/gob" +) + +type PK []interface{} + +func NewPK(pks ...interface{}) *PK { + p := PK(pks) + return &p +} + +func (p *PK) ToString() (string, error) { + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(*p) + return buf.String(), err +} + +func (p *PK) FromString(content string) error { + dec := gob.NewDecoder(bytes.NewBufferString(content)) + err := dec.Decode(p) + return err +} diff --git a/core/pk_test.go b/core/pk_test.go new file mode 100644 index 00000000..8430f95e --- /dev/null +++ b/core/pk_test.go @@ -0,0 +1,36 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "reflect" + "testing" +) + +func TestPK(t *testing.T) { + p := NewPK(1, 3, "string") + str, err := p.ToString() + if err != nil { + t.Error(err) + } + t.Log(str) + + s := &PK{} + err = s.FromString(str) + if err != nil { + t.Error(err) + } + t.Log(s) + + if len(*p) != len(*s) { + t.Fatal("p", *p, "should be equal", *s) + } + + for i, ori := range *p { + if ori != (*s)[i] { + t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i])) + } + } +} diff --git a/core/rows.go b/core/rows.go new file mode 100644 index 00000000..a1e8bfbc --- /dev/null +++ b/core/rows.go @@ -0,0 +1,338 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "database/sql" + "errors" + "reflect" + "sync" +) + +type Rows struct { + *sql.Rows + db *DB +} + +func (rs *Rows) ToMapString() ([]map[string]string, error) { + cols, err := rs.Columns() + if err != nil { + return nil, err + } + + var results = make([]map[string]string, 0, 10) + for rs.Next() { + var record = make(map[string]string, len(cols)) + err = rs.ScanMap(&record) + if err != nil { + return nil, err + } + results = append(results, record) + } + return results, nil +} + +// scan data to a struct's pointer according field index +func (rs *Rows) ScanStructByIndex(dest ...interface{}) error { + if len(dest) == 0 { + return errors.New("at least one struct") + } + + vvvs := make([]reflect.Value, len(dest)) + for i, s := range dest { + vv := reflect.ValueOf(s) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + vvvs[i] = vv.Elem() + } + + cols, err := rs.Columns() + if err != nil { + return err + } + newDest := make([]interface{}, len(cols)) + + var i = 0 + for _, vvv := range vvvs { + for j := 0; j < vvv.NumField(); j++ { + newDest[i] = vvv.Field(j).Addr().Interface() + i = i + 1 + } + } + + return rs.Rows.Scan(newDest...) +} + +var ( + fieldCache = make(map[reflect.Type]map[string]int) + fieldCacheMutex sync.RWMutex +) + +func fieldByName(v reflect.Value, name string) reflect.Value { + t := v.Type() + fieldCacheMutex.RLock() + cache, ok := fieldCache[t] + fieldCacheMutex.RUnlock() + if !ok { + cache = make(map[string]int) + for i := 0; i < v.NumField(); i++ { + cache[t.Field(i).Name] = i + } + fieldCacheMutex.Lock() + fieldCache[t] = cache + fieldCacheMutex.Unlock() + } + + if i, ok := cache[name]; ok { + return v.Field(i) + } + + return reflect.Zero(t) +} + +// scan data to a struct's pointer according field name +func (rs *Rows) ScanStructByName(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return errors.New("dest should be a struct's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + var v EmptyScanner + for j, name := range cols { + f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name)) + if f.IsValid() { + newDest[j] = f.Addr().Interface() + } else { + newDest[j] = &v + } + } + + return rs.Rows.Scan(newDest...) +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (rs *Rows) ScanSlice(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice { + return errors.New("dest should be a slice's pointer") + } + + vvv := vv.Elem() + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + + for j := 0; j < len(cols); j++ { + if j >= vvv.Len() { + newDest[j] = reflect.New(vvv.Type().Elem()).Interface() + } else { + newDest[j] = vvv.Index(j).Addr().Interface() + } + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + srcLen := vvv.Len() + for i := srcLen; i < len(cols); i++ { + vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem()) + } + return nil +} + +// scan data to a map's pointer +func (rs *Rows) ScanMap(dest interface{}) error { + vv := reflect.ValueOf(dest) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return errors.New("dest should be a map's pointer") + } + + cols, err := rs.Columns() + if err != nil { + return err + } + + newDest := make([]interface{}, len(cols)) + vvv := vv.Elem() + + for i := range cols { + newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface() + } + + err = rs.Rows.Scan(newDest...) + if err != nil { + return err + } + + for i, name := range cols { + vname := reflect.ValueOf(name) + vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem()) + } + + return nil +} + +type Row struct { + rows *Rows + // One of these two will be non-nil: + err error // deferred error for easy chaining +} + +// ErrorRow return an error row +func ErrorRow(err error) *Row { + return &Row{ + err: err, + } +} + +// NewRow from rows +func NewRow(rows *Rows, err error) *Row { + return &Row{rows, err} +} + +func (row *Row) Columns() ([]string, error) { + if row.err != nil { + return nil, row.err + } + return row.rows.Columns() +} + +func (row *Row) Scan(dest ...interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ScanStructByName(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanStructByName(dest) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ScanStructByIndex(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanStructByIndex(dest) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (row *Row) ScanSlice(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanSlice(dest) + if err != nil { + return err + } + + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +// scan data to a map's pointer +func (row *Row) ScanMap(dest interface{}) error { + if row.err != nil { + return row.err + } + defer row.rows.Close() + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.ScanMap(dest) + if err != nil { + return err + } + + // Make sure the query can be processed to completion with no errors. + return row.rows.Close() +} + +func (row *Row) ToMapString() (map[string]string, error) { + cols, err := row.Columns() + if err != nil { + return nil, err + } + + var record = make(map[string]string, len(cols)) + err = row.ScanMap(&record) + if err != nil { + return nil, err + } + + return record, nil +} diff --git a/core/scan.go b/core/scan.go new file mode 100644 index 00000000..897b5341 --- /dev/null +++ b/core/scan.go @@ -0,0 +1,66 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "database/sql/driver" + "fmt" + "time" +) + +type NullTime time.Time + +var ( + _ driver.Valuer = NullTime{} +) + +func (ns *NullTime) Scan(value interface{}) error { + if value == nil { + return nil + } + return convertTime(ns, value) +} + +// Value implements the driver Valuer interface. +func (ns NullTime) Value() (driver.Value, error) { + if (time.Time)(ns).IsZero() { + return nil, nil + } + return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil +} + +func convertTime(dest *NullTime, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + t, err := time.Parse("2006-01-02 15:04:05", s) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case []uint8: + t, err := time.Parse("2006-01-02 15:04:05", string(s)) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case time.Time: + *dest = NullTime(s) + return nil + case nil: + default: + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) + } + return nil +} + +type EmptyScanner struct { +} + +func (EmptyScanner) Scan(src interface{}) error { + return nil +} diff --git a/core/stmt.go b/core/stmt.go new file mode 100644 index 00000000..8a21541a --- /dev/null +++ b/core/stmt.go @@ -0,0 +1,166 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" + "errors" + "reflect" +) + +// Stmt reprents a stmt objects +type Stmt struct { + *sql.Stmt + db *DB + names map[string]int +} + +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i += 1 + return "?" + }) + + stmt, err := db.DB.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{stmt, db, names}, nil +} + +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + return s.Stmt.ExecContext(ctx, args...) +} + +func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) { + return s.ExecMapContext(context.Background(), mp) +} + +func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + return s.Stmt.ExecContext(ctx, args...) +} + +func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) { + return s.ExecStructContext(context.Background(), st) +} + +func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { + rows, err := s.Stmt.QueryContext(ctx, args...) + if err != nil { + return nil, err + } + return &Rows{rows, s.db}, nil +} + +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + + return s.QueryContext(ctx, args...) +} + +func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) { + return s.QueryMapContext(context.Background(), mp) +} + +func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return nil, errors.New("mp should be a map's pointer") + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + + return s.Query(args...) +} + +func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { + return s.QueryStructContext(context.Background(), st) +} + +func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { + rows, err := s.QueryContext(ctx, args...) + return &Row{rows, err} +} + +func (s *Stmt) QueryRow(args ...interface{}) *Row { + return s.QueryRowContext(context.Background(), args...) +} + +func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row { + vv := reflect.ValueOf(mp) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { + return &Row{nil, errors.New("mp should be a map's pointer")} + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface() + } + + return s.QueryRowContext(ctx, args...) +} + +func (s *Stmt) QueryRowMap(mp interface{}) *Row { + return s.QueryRowMapContext(context.Background(), mp) +} + +func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row { + vv := reflect.ValueOf(st) + if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { + return &Row{nil, errors.New("st should be a struct's pointer")} + } + + args := make([]interface{}, len(s.names)) + for k, i := range s.names { + args[i] = vv.Elem().FieldByName(k).Interface() + } + + return s.QueryRowContext(ctx, args...) +} + +func (s *Stmt) QueryRowStruct(st interface{}) *Row { + return s.QueryRowStructContext(context.Background(), st) +} diff --git a/core/table.go b/core/table.go new file mode 100644 index 00000000..0a3889e1 --- /dev/null +++ b/core/table.go @@ -0,0 +1,155 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "reflect" + "strings" +) + +// Table represents a database table +type Table struct { + Name string + Type reflect.Type + columnsSeq []string + columnsMap map[string][]*Column + columns []*Column + Indexes map[string]*Index + PrimaryKeys []string + AutoIncrement string + Created map[string]bool + Updated string + Deleted string + Version string + Cacher Cacher + StoreEngine string + Charset string + Comment string +} + +func (table *Table) Columns() []*Column { + return table.columns +} + +func (table *Table) ColumnsSeq() []string { + return table.columnsSeq +} + +func NewEmptyTable() *Table { + return NewTable("", nil) +} + +// NewTable creates a new Table object +func NewTable(name string, t reflect.Type) *Table { + return &Table{Name: name, Type: t, + columnsSeq: make([]string, 0), + columns: make([]*Column, 0), + columnsMap: make(map[string][]*Column), + Indexes: make(map[string]*Index), + Created: make(map[string]bool), + PrimaryKeys: make([]string, 0), + } +} + +func (table *Table) columnsByName(name string) []*Column { + n := len(name) + + for k := range table.columnsMap { + if len(k) != n { + continue + } + if strings.EqualFold(k, name) { + return table.columnsMap[k] + } + } + return nil +} + +func (table *Table) GetColumn(name string) *Column { + + cols := table.columnsByName(name) + + if cols != nil { + return cols[0] + } + + return nil +} + +func (table *Table) GetColumnIdx(name string, idx int) *Column { + cols := table.columnsByName(name) + + if cols != nil && idx < len(cols) { + return cols[idx] + } + + return nil +} + +// PKColumns reprents all primary key columns +func (table *Table) PKColumns() []*Column { + columns := make([]*Column, len(table.PrimaryKeys)) + for i, name := range table.PrimaryKeys { + columns[i] = table.GetColumn(name) + } + return columns +} + +func (table *Table) ColumnType(name string) reflect.Type { + t, _ := table.Type.FieldByName(name) + return t.Type +} + +func (table *Table) AutoIncrColumn() *Column { + return table.GetColumn(table.AutoIncrement) +} + +func (table *Table) VersionColumn() *Column { + return table.GetColumn(table.Version) +} + +func (table *Table) UpdatedColumn() *Column { + return table.GetColumn(table.Updated) +} + +func (table *Table) DeletedColumn() *Column { + return table.GetColumn(table.Deleted) +} + +// AddColumn adds a column to table +func (table *Table) AddColumn(col *Column) { + table.columnsSeq = append(table.columnsSeq, col.Name) + table.columns = append(table.columns, col) + colName := strings.ToLower(col.Name) + if c, ok := table.columnsMap[colName]; ok { + table.columnsMap[colName] = append(c, col) + } else { + table.columnsMap[colName] = []*Column{col} + } + + if col.IsPrimaryKey { + table.PrimaryKeys = append(table.PrimaryKeys, col.Name) + } + if col.IsAutoIncrement { + table.AutoIncrement = col.Name + } + if col.IsCreated { + table.Created[col.Name] = true + } + if col.IsUpdated { + table.Updated = col.Name + } + if col.IsDeleted { + table.Deleted = col.Name + } + if col.IsVersion { + table.Version = col.Name + } +} + +// AddIndex adds an index or an unique to table +func (table *Table) AddIndex(index *Index) { + table.Indexes[index.Name] = index +} diff --git a/core/table_test.go b/core/table_test.go new file mode 100644 index 00000000..ea31afbc --- /dev/null +++ b/core/table_test.go @@ -0,0 +1,111 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "strings" + "testing" +) + +var testsGetColumn = []struct { + name string + idx int +}{ + {"Id", 0}, + {"Deleted", 0}, + {"Caption", 0}, + {"Code_1", 0}, + {"Code_2", 0}, + {"Code_3", 0}, + {"Parent_Id", 0}, + {"Latitude", 0}, + {"Longitude", 0}, +} + +var table *Table + +func init() { + + table = NewEmptyTable() + + var name string + + for i := 0; i < len(testsGetColumn); i++ { + // as in Table.AddColumn func + name = strings.ToLower(testsGetColumn[i].name) + + table.columnsMap[name] = append(table.columnsMap[name], &Column{}) + } +} + +func TestGetColumn(t *testing.T) { + + for _, test := range testsGetColumn { + if table.GetColumn(test.name) == nil { + t.Error("Column not found!") + } + } +} + +func TestGetColumnIdx(t *testing.T) { + + for _, test := range testsGetColumn { + if table.GetColumnIdx(test.name, test.idx) == nil { + t.Errorf("Column %s with idx %d not found!", test.name, test.idx) + } + } +} + +func BenchmarkGetColumnWithToLower(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + + if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok { + b.Errorf("Column not found:%s", test.name) + } + } + } +} + +func BenchmarkGetColumnIdxWithToLower(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + + if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok { + if test.idx < len(c) { + continue + } else { + b.Errorf("Bad idx in: %s, %d", test.name, test.idx) + } + } else { + b.Errorf("Column not found: %s, %d", test.name, test.idx) + } + } + } +} + +func BenchmarkGetColumn(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + if table.GetColumn(test.name) == nil { + b.Errorf("Column not found:%s", test.name) + } + } + } +} + +func BenchmarkGetColumnIdx(b *testing.B) { + + for i := 0; i < b.N; i++ { + for _, test := range testsGetColumn { + if table.GetColumnIdx(test.name, test.idx) == nil { + b.Errorf("Column not found:%s, %d", test.name, test.idx) + } + } + } +} diff --git a/core/tx.go b/core/tx.go new file mode 100644 index 00000000..a56b7006 --- /dev/null +++ b/core/tx.go @@ -0,0 +1,153 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "context" + "database/sql" +) + +type Tx struct { + *sql.Tx + db *DB +} + +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { + tx, err := db.DB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &Tx{tx, db}, nil +} + +func (db *DB) Begin() (*Tx, error) { + tx, err := db.DB.Begin() + if err != nil { + return nil, err + } + return &Tx{tx, db}, nil +} + +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + names := make(map[string]int) + var i int + query = re.ReplaceAllStringFunc(query, func(src string) string { + names[src[1:]] = i + i += 1 + return "?" + }) + + stmt, err := tx.Tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &Stmt{stmt, tx.db, names}, nil +} + +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + stmt.Stmt = tx.Tx.StmtContext(ctx, stmt.Stmt) + return stmt +} + +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +func (tx *Tx) ExecMapContext(ctx context.Context, query string, mp interface{}) (sql.Result, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.Tx.ExecContext(ctx, query, args...) +} + +func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) { + return tx.ExecMapContext(context.Background(), query, mp) +} + +func (tx *Tx) ExecStructContext(ctx context.Context, query string, st interface{}) (sql.Result, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.Tx.ExecContext(ctx, query, args...) +} + +func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) { + return tx.ExecStructContext(context.Background(), query, st) +} + +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + rows, err := tx.Tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &Rows{rows, tx.db}, nil +} + +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +func (tx *Tx) QueryMapContext(ctx context.Context, query string, mp interface{}) (*Rows, error) { + query, args, err := MapToSlice(query, mp) + if err != nil { + return nil, err + } + return tx.QueryContext(ctx, query, args...) +} + +func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) { + return tx.QueryMapContext(context.Background(), query, mp) +} + +func (tx *Tx) QueryStructContext(ctx context.Context, query string, st interface{}) (*Rows, error) { + query, args, err := StructToSlice(query, st) + if err != nil { + return nil, err + } + return tx.QueryContext(ctx, query, args...) +} + +func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { + return tx.QueryStructContext(context.Background(), query, st) +} + +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows, err} +} + +func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { + return tx.QueryRowContext(context.Background(), query, args...) +} + +func (tx *Tx) QueryRowMapContext(ctx context.Context, query string, mp interface{}) *Row { + query, args, err := MapToSlice(query, mp) + if err != nil { + return &Row{nil, err} + } + return tx.QueryRowContext(ctx, query, args...) +} + +func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { + return tx.QueryRowMapContext(context.Background(), query, mp) +} + +func (tx *Tx) QueryRowStructContext(ctx context.Context, query string, st interface{}) *Row { + query, args, err := StructToSlice(query, st) + if err != nil { + return &Row{nil, err} + } + return tx.QueryRowContext(ctx, query, args...) +} + +func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { + return tx.QueryRowStructContext(context.Background(), query, st) +} diff --git a/core/type.go b/core/type.go new file mode 100644 index 00000000..75b6c363 --- /dev/null +++ b/core/type.go @@ -0,0 +1,325 @@ +// Copyright 2019 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package core + +import ( + "reflect" + "sort" + "strings" + "time" +) + +const ( + POSTGRES = "postgres" + SQLITE = "sqlite3" + MYSQL = "mysql" + MSSQL = "mssql" + ORACLE = "oracle" +) + +// xorm SQL types +type SQLType struct { + Name string + DefaultLength int + DefaultLength2 int +} + +const ( + UNKNOW_TYPE = iota + TEXT_TYPE + BLOB_TYPE + TIME_TYPE + NUMERIC_TYPE +) + +func (s *SQLType) IsType(st int) bool { + if t, ok := SqlTypes[s.Name]; ok && t == st { + return true + } + return false +} + +func (s *SQLType) IsText() bool { + return s.IsType(TEXT_TYPE) +} + +func (s *SQLType) IsBlob() bool { + return s.IsType(BLOB_TYPE) +} + +func (s *SQLType) IsTime() bool { + return s.IsType(TIME_TYPE) +} + +func (s *SQLType) IsNumeric() bool { + return s.IsType(NUMERIC_TYPE) +} + +func (s *SQLType) IsJson() bool { + return s.Name == Json || s.Name == Jsonb +} + +var ( + Bit = "BIT" + TinyInt = "TINYINT" + SmallInt = "SMALLINT" + MediumInt = "MEDIUMINT" + Int = "INT" + Integer = "INTEGER" + BigInt = "BIGINT" + + Enum = "ENUM" + Set = "SET" + + Char = "CHAR" + Varchar = "VARCHAR" + NChar = "NCHAR" + NVarchar = "NVARCHAR" + TinyText = "TINYTEXT" + Text = "TEXT" + NText = "NTEXT" + Clob = "CLOB" + MediumText = "MEDIUMTEXT" + LongText = "LONGTEXT" + Uuid = "UUID" + UniqueIdentifier = "UNIQUEIDENTIFIER" + SysName = "SYSNAME" + + Date = "DATE" + DateTime = "DATETIME" + SmallDateTime = "SMALLDATETIME" + Time = "TIME" + TimeStamp = "TIMESTAMP" + TimeStampz = "TIMESTAMPZ" + Year = "YEAR" + + Decimal = "DECIMAL" + Numeric = "NUMERIC" + Money = "MONEY" + SmallMoney = "SMALLMONEY" + + Real = "REAL" + Float = "FLOAT" + Double = "DOUBLE" + + Binary = "BINARY" + VarBinary = "VARBINARY" + TinyBlob = "TINYBLOB" + Blob = "BLOB" + MediumBlob = "MEDIUMBLOB" + LongBlob = "LONGBLOB" + Bytea = "BYTEA" + + Bool = "BOOL" + Boolean = "BOOLEAN" + + Serial = "SERIAL" + BigSerial = "BIGSERIAL" + + Json = "JSON" + Jsonb = "JSONB" + + SqlTypes = map[string]int{ + Bit: NUMERIC_TYPE, + TinyInt: NUMERIC_TYPE, + SmallInt: NUMERIC_TYPE, + MediumInt: NUMERIC_TYPE, + Int: NUMERIC_TYPE, + Integer: NUMERIC_TYPE, + BigInt: NUMERIC_TYPE, + + Enum: TEXT_TYPE, + Set: TEXT_TYPE, + Json: TEXT_TYPE, + Jsonb: TEXT_TYPE, + + Char: TEXT_TYPE, + NChar: TEXT_TYPE, + Varchar: TEXT_TYPE, + NVarchar: TEXT_TYPE, + TinyText: TEXT_TYPE, + Text: TEXT_TYPE, + NText: TEXT_TYPE, + MediumText: TEXT_TYPE, + LongText: TEXT_TYPE, + Uuid: TEXT_TYPE, + Clob: TEXT_TYPE, + SysName: TEXT_TYPE, + + Date: TIME_TYPE, + DateTime: TIME_TYPE, + Time: TIME_TYPE, + TimeStamp: TIME_TYPE, + TimeStampz: TIME_TYPE, + SmallDateTime: TIME_TYPE, + Year: TIME_TYPE, + + Decimal: NUMERIC_TYPE, + Numeric: NUMERIC_TYPE, + Real: NUMERIC_TYPE, + Float: NUMERIC_TYPE, + Double: NUMERIC_TYPE, + Money: NUMERIC_TYPE, + SmallMoney: NUMERIC_TYPE, + + Binary: BLOB_TYPE, + VarBinary: BLOB_TYPE, + + TinyBlob: BLOB_TYPE, + Blob: BLOB_TYPE, + MediumBlob: BLOB_TYPE, + LongBlob: BLOB_TYPE, + Bytea: BLOB_TYPE, + UniqueIdentifier: BLOB_TYPE, + + Bool: NUMERIC_TYPE, + + Serial: NUMERIC_TYPE, + BigSerial: NUMERIC_TYPE, + } + + intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"} + uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"} +) + +// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparison +var ( + c_EMPTY_STRING string + c_BOOL_DEFAULT bool + c_BYTE_DEFAULT byte + c_COMPLEX64_DEFAULT complex64 + c_COMPLEX128_DEFAULT complex128 + c_FLOAT32_DEFAULT float32 + c_FLOAT64_DEFAULT float64 + c_INT64_DEFAULT int64 + c_UINT64_DEFAULT uint64 + c_INT32_DEFAULT int32 + c_UINT32_DEFAULT uint32 + c_INT16_DEFAULT int16 + c_UINT16_DEFAULT uint16 + c_INT8_DEFAULT int8 + c_UINT8_DEFAULT uint8 + c_INT_DEFAULT int + c_UINT_DEFAULT uint + c_TIME_DEFAULT time.Time +) + +var ( + IntType = reflect.TypeOf(c_INT_DEFAULT) + Int8Type = reflect.TypeOf(c_INT8_DEFAULT) + Int16Type = reflect.TypeOf(c_INT16_DEFAULT) + Int32Type = reflect.TypeOf(c_INT32_DEFAULT) + Int64Type = reflect.TypeOf(c_INT64_DEFAULT) + + UintType = reflect.TypeOf(c_UINT_DEFAULT) + Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT) + Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT) + Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT) + Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT) + + Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT) + Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT) + + Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT) + Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT) + + StringType = reflect.TypeOf(c_EMPTY_STRING) + BoolType = reflect.TypeOf(c_BOOL_DEFAULT) + ByteType = reflect.TypeOf(c_BYTE_DEFAULT) + BytesType = reflect.SliceOf(ByteType) + + TimeType = reflect.TypeOf(c_TIME_DEFAULT) +) + +var ( + PtrIntType = reflect.PtrTo(IntType) + PtrInt8Type = reflect.PtrTo(Int8Type) + PtrInt16Type = reflect.PtrTo(Int16Type) + PtrInt32Type = reflect.PtrTo(Int32Type) + PtrInt64Type = reflect.PtrTo(Int64Type) + + PtrUintType = reflect.PtrTo(UintType) + PtrUint8Type = reflect.PtrTo(Uint8Type) + PtrUint16Type = reflect.PtrTo(Uint16Type) + PtrUint32Type = reflect.PtrTo(Uint32Type) + PtrUint64Type = reflect.PtrTo(Uint64Type) + + PtrFloat32Type = reflect.PtrTo(Float32Type) + PtrFloat64Type = reflect.PtrTo(Float64Type) + + PtrComplex64Type = reflect.PtrTo(Complex64Type) + PtrComplex128Type = reflect.PtrTo(Complex128Type) + + PtrStringType = reflect.PtrTo(StringType) + PtrBoolType = reflect.PtrTo(BoolType) + PtrByteType = reflect.PtrTo(ByteType) + + PtrTimeType = reflect.PtrTo(TimeType) +) + +// Type2SQLType generate SQLType acorrding Go's type +func Type2SQLType(t reflect.Type) (st SQLType) { + switch k := t.Kind(); k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + st = SQLType{Int, 0, 0} + case reflect.Int64, reflect.Uint64: + st = SQLType{BigInt, 0, 0} + case reflect.Float32: + st = SQLType{Float, 0, 0} + case reflect.Float64: + st = SQLType{Double, 0, 0} + case reflect.Complex64, reflect.Complex128: + st = SQLType{Varchar, 64, 0} + case reflect.Array, reflect.Slice, reflect.Map: + if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) { + st = SQLType{Blob, 0, 0} + } else { + st = SQLType{Text, 0, 0} + } + case reflect.Bool: + st = SQLType{Bool, 0, 0} + case reflect.String: + st = SQLType{Varchar, 255, 0} + case reflect.Struct: + if t.ConvertibleTo(TimeType) { + st = SQLType{DateTime, 0, 0} + } else { + // TODO need to handle association struct + st = SQLType{Text, 0, 0} + } + case reflect.Ptr: + st = Type2SQLType(t.Elem()) + default: + st = SQLType{Text, 0, 0} + } + return +} + +// default sql type change to go types +func SQLType2Type(st SQLType) reflect.Type { + name := strings.ToUpper(st.Name) + switch name { + case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial: + return reflect.TypeOf(1) + case BigInt, BigSerial: + return reflect.TypeOf(int64(1)) + case Float, Real: + return reflect.TypeOf(float32(1)) + case Double: + return reflect.TypeOf(float64(1)) + case Char, NChar, Varchar, NVarchar, TinyText, Text, NText, MediumText, LongText, Enum, Set, Uuid, Clob, SysName: + return reflect.TypeOf("") + case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary, UniqueIdentifier: + return reflect.TypeOf([]byte{}) + case Bool: + return reflect.TypeOf(true) + case DateTime, Date, Time, TimeStamp, TimeStampz, SmallDateTime, Year: + return reflect.TypeOf(c_TIME_DEFAULT) + case Decimal, Numeric, Money, SmallMoney: + return reflect.TypeOf("") + default: + return reflect.TypeOf("") + } +}