don't ignore vendor folder
This commit is contained in:
parent
0691bf9b03
commit
20a684b85a
|
|
@ -21,7 +21,6 @@ _cgo_export.*
|
|||
_testmain.go
|
||||
|
||||
*.exe
|
||||
vendor
|
||||
|
||||
*.log
|
||||
.vendor
|
||||
|
|
|
|||
|
|
@ -46,17 +46,17 @@ func TestSum(t *testing.T) {
|
|||
colInt := testEngine.ColumnMapper.Obj2Table("Int")
|
||||
colFloat := testEngine.ColumnMapper.Obj2Table("Float")
|
||||
|
||||
sumInt, err := testEngine.Sum(new(SumStruct), "`"+colInt+"`")
|
||||
sumInt, err := testEngine.Sum(new(SumStruct), colInt)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, int(sumInt), i)
|
||||
|
||||
sumFloat, err := testEngine.Sum(new(SumStruct), "`"+colFloat+"`")
|
||||
sumFloat, err := testEngine.Sum(new(SumStruct), colFloat)
|
||||
assert.NoError(t, err)
|
||||
assert.Condition(t, func() bool {
|
||||
return isFloatEq(sumFloat, float64(f), 2)
|
||||
})
|
||||
|
||||
sums, err := testEngine.Sums(new(SumStruct), "`"+colInt+"`", "`"+colFloat+"`")
|
||||
sums, err := testEngine.Sums(new(SumStruct), colInt, colFloat)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 2, len(sums))
|
||||
assert.EqualValues(t, i, int(sums[0]))
|
||||
|
|
@ -64,7 +64,7 @@ func TestSum(t *testing.T) {
|
|||
return isFloatEq(sums[1], float64(f), 2)
|
||||
})
|
||||
|
||||
sumsInt, err := testEngine.SumsInt(new(SumStruct), "`"+colInt+"`")
|
||||
sumsInt, err := testEngine.SumsInt(new(SumStruct), colInt)
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 1, len(sumsInt))
|
||||
assert.EqualValues(t, i, int(sumsInt[0]))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
Copyright (c) 2012-2013 Dave Collins <dave@davec.name>
|
||||
|
||||
Permission to use, copy, modify, and distribute this software for any
|
||||
purpose with or without fee is hereby granted, provided that the above
|
||||
copyright notice and this permission notice appear in all copies.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
|
@ -0,0 +1,346 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the internal
|
||||
// reflect.Value fields.
|
||||
var offsetPtr, offsetScalar, offsetFlag uintptr
|
||||
|
||||
// reflectValueOld mirrors the struct layout of the reflect package Value type
|
||||
// before golang commit ecccf07e7f9d.
|
||||
var reflectValueOld struct {
|
||||
typ unsafe.Pointer
|
||||
val unsafe.Pointer
|
||||
flag uintptr
|
||||
}
|
||||
|
||||
// reflectValueNew mirrors the struct layout of the reflect package Value type
|
||||
// after golang commit ecccf07e7f9d.
|
||||
var reflectValueNew struct {
|
||||
typ unsafe.Pointer
|
||||
ptr unsafe.Pointer
|
||||
scalar uintptr
|
||||
flag uintptr
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Older versions of reflect.Value stored small integers directly in the
|
||||
// ptr field (which is named val in the older versions). Newer versions
|
||||
// added a new field named scalar for this purpose which unfortuantely
|
||||
// comes before the flag field. Further the new field is before the
|
||||
// flag field, so the offset of the flag field is different as well.
|
||||
// This code constructs a new reflect.Value from a known small integer
|
||||
// and checks if the val field within it matches. When it matches, the
|
||||
// old style reflect.Value is being used. Otherwise it's the new style.
|
||||
v := 0xf00
|
||||
vv := reflect.ValueOf(v)
|
||||
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) +
|
||||
unsafe.Offsetof(reflectValueOld.val))
|
||||
|
||||
// Assume the old style by default.
|
||||
offsetPtr = unsafe.Offsetof(reflectValueOld.val)
|
||||
offsetScalar = 0
|
||||
offsetFlag = unsafe.Offsetof(reflectValueOld.flag)
|
||||
|
||||
// Use the new style offsets if the ptr field doesn't match the value
|
||||
// since it must be in the new scalar field.
|
||||
if int(*(*uintptr)(upv)) != v {
|
||||
offsetPtr = unsafe.Offsetof(reflectValueNew.ptr)
|
||||
offsetScalar = unsafe.Offsetof(reflectValueNew.scalar)
|
||||
offsetFlag = unsafe.Offsetof(reflectValueNew.flag)
|
||||
}
|
||||
}
|
||||
|
||||
// flagIndir indicates whether the value field of a reflect.Value is the actual
|
||||
// data or a pointer to the data.
|
||||
const flagIndir = 1 << 1
|
||||
|
||||
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
|
||||
// the typical safety restrictions preventing access to unaddressable and
|
||||
// unexported data. It works by digging the raw pointer to the underlying
|
||||
// value out of the protected value and generating a new unprotected (unsafe)
|
||||
// reflect.Value to it.
|
||||
//
|
||||
// This allows us to check for implementations of the Stringer and error
|
||||
// interfaces to be used for pretty printing ordinarily unaddressable and
|
||||
// inaccessible values such as unexported struct fields.
|
||||
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
|
||||
indirects := 1
|
||||
vt := v.Type()
|
||||
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
|
||||
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
|
||||
if rvf&flagIndir != 0 {
|
||||
vt = reflect.PtrTo(v.Type())
|
||||
indirects++
|
||||
} else if offsetScalar != 0 {
|
||||
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetScalar)
|
||||
}
|
||||
|
||||
pv := reflect.NewAt(vt, upv)
|
||||
rv = pv
|
||||
for i := 0; i < indirects; i++ {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
return rv
|
||||
}
|
||||
|
||||
// Some constants in the form of bytes to avoid string overhead. This mirrors
|
||||
// the technique used in the fmt package.
|
||||
var (
|
||||
panicBytes = []byte("(PANIC=")
|
||||
plusBytes = []byte("+")
|
||||
iBytes = []byte("i")
|
||||
trueBytes = []byte("true")
|
||||
falseBytes = []byte("false")
|
||||
interfaceBytes = []byte("(interface {})")
|
||||
commaNewlineBytes = []byte(",\n")
|
||||
newlineBytes = []byte("\n")
|
||||
openBraceBytes = []byte("{")
|
||||
openBraceNewlineBytes = []byte("{\n")
|
||||
closeBraceBytes = []byte("}")
|
||||
asteriskBytes = []byte("*")
|
||||
colonBytes = []byte(":")
|
||||
colonSpaceBytes = []byte(": ")
|
||||
openParenBytes = []byte("(")
|
||||
closeParenBytes = []byte(")")
|
||||
spaceBytes = []byte(" ")
|
||||
pointerChainBytes = []byte("->")
|
||||
nilAngleBytes = []byte("<nil>")
|
||||
maxNewlineBytes = []byte("<max depth reached>\n")
|
||||
maxShortBytes = []byte("<max>")
|
||||
circularBytes = []byte("<already shown>")
|
||||
circularShortBytes = []byte("<shown>")
|
||||
invalidAngleBytes = []byte("<invalid>")
|
||||
openBracketBytes = []byte("[")
|
||||
closeBracketBytes = []byte("]")
|
||||
percentBytes = []byte("%")
|
||||
precisionBytes = []byte(".")
|
||||
openAngleBytes = []byte("<")
|
||||
closeAngleBytes = []byte(">")
|
||||
openMapBytes = []byte("map[")
|
||||
closeMapBytes = []byte("]")
|
||||
lenEqualsBytes = []byte("len=")
|
||||
capEqualsBytes = []byte("cap=")
|
||||
)
|
||||
|
||||
// hexDigits is used to map a decimal value to a hex digit.
|
||||
var hexDigits = "0123456789abcdef"
|
||||
|
||||
// catchPanic handles any panics that might occur during the handleMethods
|
||||
// calls.
|
||||
func catchPanic(w io.Writer, v reflect.Value) {
|
||||
if err := recover(); err != nil {
|
||||
w.Write(panicBytes)
|
||||
fmt.Fprintf(w, "%v", err)
|
||||
w.Write(closeParenBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// handleMethods attempts to call the Error and String methods on the underlying
|
||||
// type the passed reflect.Value represents and outputes the result to Writer w.
|
||||
//
|
||||
// It handles panics in any called methods by catching and displaying the error
|
||||
// as the formatted value.
|
||||
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
|
||||
// We need an interface to check if the type implements the error or
|
||||
// Stringer interface. However, the reflect package won't give us an
|
||||
// interface on certain things like unexported struct fields in order
|
||||
// to enforce visibility rules. We use unsafe to bypass these restrictions
|
||||
// since this package does not mutate the values.
|
||||
if !v.CanInterface() {
|
||||
v = unsafeReflectValue(v)
|
||||
}
|
||||
|
||||
// Choose whether or not to do error and Stringer interface lookups against
|
||||
// the base type or a pointer to the base type depending on settings.
|
||||
// Technically calling one of these methods with a pointer receiver can
|
||||
// mutate the value, however, types which choose to satisify an error or
|
||||
// Stringer interface with a pointer receiver should not be mutating their
|
||||
// state inside these interface methods.
|
||||
var viface interface{}
|
||||
if !cs.DisablePointerMethods {
|
||||
if !v.CanAddr() {
|
||||
v = unsafeReflectValue(v)
|
||||
}
|
||||
viface = v.Addr().Interface()
|
||||
} else {
|
||||
if v.CanAddr() {
|
||||
v = v.Addr()
|
||||
}
|
||||
viface = v.Interface()
|
||||
}
|
||||
|
||||
// Is it an error or Stringer?
|
||||
switch iface := viface.(type) {
|
||||
case error:
|
||||
defer catchPanic(w, v)
|
||||
if cs.ContinueOnMethod {
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(iface.Error()))
|
||||
w.Write(closeParenBytes)
|
||||
w.Write(spaceBytes)
|
||||
return false
|
||||
}
|
||||
|
||||
w.Write([]byte(iface.Error()))
|
||||
return true
|
||||
|
||||
case fmt.Stringer:
|
||||
defer catchPanic(w, v)
|
||||
if cs.ContinueOnMethod {
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(iface.String()))
|
||||
w.Write(closeParenBytes)
|
||||
w.Write(spaceBytes)
|
||||
return false
|
||||
}
|
||||
w.Write([]byte(iface.String()))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// printBool outputs a boolean value as true or false to Writer w.
|
||||
func printBool(w io.Writer, val bool) {
|
||||
if val {
|
||||
w.Write(trueBytes)
|
||||
} else {
|
||||
w.Write(falseBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// printInt outputs a signed integer value to Writer w.
|
||||
func printInt(w io.Writer, val int64, base int) {
|
||||
w.Write([]byte(strconv.FormatInt(val, base)))
|
||||
}
|
||||
|
||||
// printUint outputs an unsigned integer value to Writer w.
|
||||
func printUint(w io.Writer, val uint64, base int) {
|
||||
w.Write([]byte(strconv.FormatUint(val, base)))
|
||||
}
|
||||
|
||||
// printFloat outputs a floating point value using the specified precision,
|
||||
// which is expected to be 32 or 64bit, to Writer w.
|
||||
func printFloat(w io.Writer, val float64, precision int) {
|
||||
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
|
||||
}
|
||||
|
||||
// printComplex outputs a complex value using the specified float precision
|
||||
// for the real and imaginary parts to Writer w.
|
||||
func printComplex(w io.Writer, c complex128, floatPrecision int) {
|
||||
r := real(c)
|
||||
w.Write(openParenBytes)
|
||||
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
|
||||
i := imag(c)
|
||||
if i >= 0 {
|
||||
w.Write(plusBytes)
|
||||
}
|
||||
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
|
||||
w.Write(iBytes)
|
||||
w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
|
||||
// prefix to Writer w.
|
||||
func printHexPtr(w io.Writer, p uintptr) {
|
||||
// Null pointer.
|
||||
num := uint64(p)
|
||||
if num == 0 {
|
||||
w.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
|
||||
buf := make([]byte, 18)
|
||||
|
||||
// It's simpler to construct the hex string right to left.
|
||||
base := uint64(16)
|
||||
i := len(buf) - 1
|
||||
for num >= base {
|
||||
buf[i] = hexDigits[num%base]
|
||||
num /= base
|
||||
i--
|
||||
}
|
||||
buf[i] = hexDigits[num]
|
||||
|
||||
// Add '0x' prefix.
|
||||
i--
|
||||
buf[i] = 'x'
|
||||
i--
|
||||
buf[i] = '0'
|
||||
|
||||
// Strip unused leading bytes.
|
||||
buf = buf[i:]
|
||||
w.Write(buf)
|
||||
}
|
||||
|
||||
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
|
||||
// elements to be sorted.
|
||||
type valuesSorter struct {
|
||||
values []reflect.Value
|
||||
}
|
||||
|
||||
// Len returns the number of values in the slice. It is part of the
|
||||
// sort.Interface implementation.
|
||||
func (s *valuesSorter) Len() int {
|
||||
return len(s.values)
|
||||
}
|
||||
|
||||
// Swap swaps the values at the passed indices. It is part of the
|
||||
// sort.Interface implementation.
|
||||
func (s *valuesSorter) Swap(i, j int) {
|
||||
s.values[i], s.values[j] = s.values[j], s.values[i]
|
||||
}
|
||||
|
||||
// Less returns whether the value at index i should sort before the
|
||||
// value at index j. It is part of the sort.Interface implementation.
|
||||
func (s *valuesSorter) Less(i, j int) bool {
|
||||
switch s.values[i].Kind() {
|
||||
case reflect.Bool:
|
||||
return !s.values[i].Bool() && s.values[j].Bool()
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return s.values[i].Int() < s.values[j].Int()
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return s.values[i].Uint() < s.values[j].Uint()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return s.values[i].Float() < s.values[j].Float()
|
||||
case reflect.String:
|
||||
return s.values[i].String() < s.values[j].String()
|
||||
case reflect.Uintptr:
|
||||
return s.values[i].Uint() < s.values[j].Uint()
|
||||
}
|
||||
return s.values[i].String() < s.values[j].String()
|
||||
}
|
||||
|
||||
// sortValues is a generic sort function for native types: int, uint, bool,
|
||||
// string and uintptr. Other inputs are sorted according to their
|
||||
// Value.String() value to ensure display stability.
|
||||
func sortValues(values []reflect.Value) {
|
||||
if len(values) == 0 {
|
||||
return
|
||||
}
|
||||
sort.Sort(&valuesSorter{values})
|
||||
}
|
||||
|
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// custom type to test Stinger interface on non-pointer receiver.
|
||||
type stringer string
|
||||
|
||||
// String implements the Stringer interface for testing invocation of custom
|
||||
// stringers on types with non-pointer receivers.
|
||||
func (s stringer) String() string {
|
||||
return "stringer " + string(s)
|
||||
}
|
||||
|
||||
// custom type to test Stinger interface on pointer receiver.
|
||||
type pstringer string
|
||||
|
||||
// String implements the Stringer interface for testing invocation of custom
|
||||
// stringers on types with only pointer receivers.
|
||||
func (s *pstringer) String() string {
|
||||
return "stringer " + string(*s)
|
||||
}
|
||||
|
||||
// xref1 and xref2 are cross referencing structs for testing circular reference
|
||||
// detection.
|
||||
type xref1 struct {
|
||||
ps2 *xref2
|
||||
}
|
||||
type xref2 struct {
|
||||
ps1 *xref1
|
||||
}
|
||||
|
||||
// indirCir1, indirCir2, and indirCir3 are used to generate an indirect circular
|
||||
// reference for testing detection.
|
||||
type indirCir1 struct {
|
||||
ps2 *indirCir2
|
||||
}
|
||||
type indirCir2 struct {
|
||||
ps3 *indirCir3
|
||||
}
|
||||
type indirCir3 struct {
|
||||
ps1 *indirCir1
|
||||
}
|
||||
|
||||
// embed is used to test embedded structures.
|
||||
type embed struct {
|
||||
a string
|
||||
}
|
||||
|
||||
// embedwrap is used to test embedded structures.
|
||||
type embedwrap struct {
|
||||
*embed
|
||||
e *embed
|
||||
}
|
||||
|
||||
// panicer is used to intentionally cause a panic for testing spew properly
|
||||
// handles them
|
||||
type panicer int
|
||||
|
||||
func (p panicer) String() string {
|
||||
panic("test panic")
|
||||
}
|
||||
|
||||
// customError is used to test custom error interface invocation.
|
||||
type customError int
|
||||
|
||||
func (e customError) Error() string {
|
||||
return fmt.Sprintf("error: %d", int(e))
|
||||
}
|
||||
|
||||
// stringizeWants converts a slice of wanted test output into a format suitable
|
||||
// for a test error message.
|
||||
func stringizeWants(wants []string) string {
|
||||
s := ""
|
||||
for i, want := range wants {
|
||||
if i > 0 {
|
||||
s += fmt.Sprintf("want%d: %s", i+1, want)
|
||||
} else {
|
||||
s += "want: " + want
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// testFailed returns whether or not a test failed by checking if the result
|
||||
// of the test is in the slice of wanted strings.
|
||||
func testFailed(result string, wants []string) bool {
|
||||
for _, want := range wants {
|
||||
if result == want {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// TestSortValues ensures the sort functionality for relect.Value based sorting
|
||||
// works as intended.
|
||||
func TestSortValues(t *testing.T) {
|
||||
getInterfaces := func(values []reflect.Value) []interface{} {
|
||||
interfaces := []interface{}{}
|
||||
for _, v := range values {
|
||||
interfaces = append(interfaces, v.Interface())
|
||||
}
|
||||
return interfaces
|
||||
}
|
||||
|
||||
v := reflect.ValueOf
|
||||
|
||||
a := v("a")
|
||||
b := v("b")
|
||||
c := v("c")
|
||||
embedA := v(embed{"a"})
|
||||
embedB := v(embed{"b"})
|
||||
embedC := v(embed{"c"})
|
||||
tests := []struct {
|
||||
input []reflect.Value
|
||||
expected []reflect.Value
|
||||
}{
|
||||
// No values.
|
||||
{
|
||||
[]reflect.Value{},
|
||||
[]reflect.Value{},
|
||||
},
|
||||
// Bools.
|
||||
{
|
||||
[]reflect.Value{v(false), v(true), v(false)},
|
||||
[]reflect.Value{v(false), v(false), v(true)},
|
||||
},
|
||||
// Ints.
|
||||
{
|
||||
[]reflect.Value{v(2), v(1), v(3)},
|
||||
[]reflect.Value{v(1), v(2), v(3)},
|
||||
},
|
||||
// Uints.
|
||||
{
|
||||
[]reflect.Value{v(uint8(2)), v(uint8(1)), v(uint8(3))},
|
||||
[]reflect.Value{v(uint8(1)), v(uint8(2)), v(uint8(3))},
|
||||
},
|
||||
// Floats.
|
||||
{
|
||||
[]reflect.Value{v(2.0), v(1.0), v(3.0)},
|
||||
[]reflect.Value{v(1.0), v(2.0), v(3.0)},
|
||||
},
|
||||
// Strings.
|
||||
{
|
||||
[]reflect.Value{b, a, c},
|
||||
[]reflect.Value{a, b, c},
|
||||
},
|
||||
// Uintptrs.
|
||||
{
|
||||
[]reflect.Value{v(uintptr(2)), v(uintptr(1)), v(uintptr(3))},
|
||||
[]reflect.Value{v(uintptr(1)), v(uintptr(2)), v(uintptr(3))},
|
||||
},
|
||||
// Invalid.
|
||||
{
|
||||
[]reflect.Value{embedB, embedA, embedC},
|
||||
[]reflect.Value{embedB, embedA, embedC},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
spew.SortValues(test.input)
|
||||
// reflect.DeepEqual cannot really make sense of reflect.Value,
|
||||
// probably because of all the pointer tricks. For instance,
|
||||
// v(2.0) != v(2.0) on a 32-bits system. Turn them into interface{}
|
||||
// instead.
|
||||
input := getInterfaces(test.input)
|
||||
expected := getInterfaces(test.expected)
|
||||
if !reflect.DeepEqual(input, expected) {
|
||||
t.Errorf("Sort mismatch:\n %v != %v", input, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,288 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ConfigState houses the configuration options used by spew to format and
|
||||
// display values. There is a global instance, Config, that is used to control
|
||||
// all top-level Formatter and Dump functionality. Each ConfigState instance
|
||||
// provides methods equivalent to the top-level functions.
|
||||
//
|
||||
// The zero value for ConfigState provides no indentation. You would typically
|
||||
// want to set it to a space or a tab.
|
||||
//
|
||||
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
|
||||
// with default settings. See the documentation of NewDefaultConfig for default
|
||||
// values.
|
||||
type ConfigState struct {
|
||||
// Indent specifies the string to use for each indentation level. The
|
||||
// global config instance that all top-level functions use set this to a
|
||||
// single space by default. If you would like more indentation, you might
|
||||
// set this to a tab with "\t" or perhaps two spaces with " ".
|
||||
Indent string
|
||||
|
||||
// MaxDepth controls the maximum number of levels to descend into nested
|
||||
// data structures. The default, 0, means there is no limit.
|
||||
//
|
||||
// NOTE: Circular data structures are properly detected, so it is not
|
||||
// necessary to set this value unless you specifically want to limit deeply
|
||||
// nested data structures.
|
||||
MaxDepth int
|
||||
|
||||
// DisableMethods specifies whether or not error and Stringer interfaces are
|
||||
// invoked for types that implement them.
|
||||
DisableMethods bool
|
||||
|
||||
// DisablePointerMethods specifies whether or not to check for and invoke
|
||||
// error and Stringer interfaces on types which only accept a pointer
|
||||
// receiver when the current type is not a pointer.
|
||||
//
|
||||
// NOTE: This might be an unsafe action since calling one of these methods
|
||||
// with a pointer receiver could technically mutate the value, however,
|
||||
// in practice, types which choose to satisify an error or Stringer
|
||||
// interface with a pointer receiver should not be mutating their state
|
||||
// inside these interface methods.
|
||||
DisablePointerMethods bool
|
||||
|
||||
// ContinueOnMethod specifies whether or not recursion should continue once
|
||||
// a custom error or Stringer interface is invoked. The default, false,
|
||||
// means it will print the results of invoking the custom error or Stringer
|
||||
// interface and return immediately instead of continuing to recurse into
|
||||
// the internals of the data type.
|
||||
//
|
||||
// NOTE: This flag does not have any effect if method invocation is disabled
|
||||
// via the DisableMethods or DisablePointerMethods options.
|
||||
ContinueOnMethod bool
|
||||
|
||||
// SortKeys specifies map keys should be sorted before being printed. Use
|
||||
// this to have a more deterministic, diffable output. Note that only
|
||||
// native types (bool, int, uint, floats, uintptr and string) are supported
|
||||
// with other types sorted according to the reflect.Value.String() output
|
||||
// which guarantees display stability.
|
||||
SortKeys bool
|
||||
}
|
||||
|
||||
// Config is the active configuration of the top-level functions.
|
||||
// The configuration can be changed by modifying the contents of spew.Config.
|
||||
var Config = ConfigState{Indent: " "}
|
||||
|
||||
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the formatted string as a value that satisfies error. See NewFormatter
|
||||
// for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
|
||||
return fmt.Errorf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprint(w, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintf(w, format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||
// passed with a Formatter interface returned by c.NewFormatter. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintln(w, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
|
||||
return fmt.Print(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Printf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
|
||||
return fmt.Println(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprint(a ...interface{}) string {
|
||||
return fmt.Sprint(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||
// the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
|
||||
return fmt.Sprintf(format, c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||
// were passed with a Formatter interface returned by c.NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
|
||||
func (c *ConfigState) Sprintln(a ...interface{}) string {
|
||||
return fmt.Sprintln(c.convertArgs(a)...)
|
||||
}
|
||||
|
||||
/*
|
||||
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||
interface. As a result, it integrates cleanly with standard fmt package
|
||||
printing functions. The formatter is useful for inline printing of smaller data
|
||||
types similar to the standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Typically this function shouldn't be called directly. It is much easier to make
|
||||
use of the custom formatter by calling one of the convenience functions such as
|
||||
c.Printf, c.Println, or c.Printf.
|
||||
*/
|
||||
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
|
||||
return newFormatter(c, v)
|
||||
}
|
||||
|
||||
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||
// exactly the same as Dump.
|
||||
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
|
||||
fdump(c, w, a...)
|
||||
}
|
||||
|
||||
/*
|
||||
Dump displays the passed parameters to standard out with newlines, customizable
|
||||
indentation, and additional debug information such as complete types and all
|
||||
pointer addresses used to indirect to the final value. It provides the
|
||||
following features over the built-in printing facilities provided by the fmt
|
||||
package:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output
|
||||
|
||||
The configuration options are controlled by modifying the public members
|
||||
of c. See ConfigState for options documentation.
|
||||
|
||||
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||
get the formatted result as a string.
|
||||
*/
|
||||
func (c *ConfigState) Dump(a ...interface{}) {
|
||||
fdump(c, os.Stdout, a...)
|
||||
}
|
||||
|
||||
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||
// as Dump.
|
||||
func (c *ConfigState) Sdump(a ...interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
fdump(c, &buf, a...)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||
// length with each argument converted to a spew Formatter interface using
|
||||
// the ConfigState associated with s.
|
||||
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
|
||||
formatters = make([]interface{}, len(args))
|
||||
for index, arg := range args {
|
||||
formatters[index] = newFormatter(c, arg)
|
||||
}
|
||||
return formatters
|
||||
}
|
||||
|
||||
// NewDefaultConfig returns a ConfigState with the following default settings.
|
||||
//
|
||||
// Indent: " "
|
||||
// MaxDepth: 0
|
||||
// DisableMethods: false
|
||||
// DisablePointerMethods: false
|
||||
// ContinueOnMethod: false
|
||||
// SortKeys: false
|
||||
func NewDefaultConfig() *ConfigState {
|
||||
return &ConfigState{Indent: " "}
|
||||
}
|
||||
|
|
@ -0,0 +1,196 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
/*
|
||||
Package spew implements a deep pretty printer for Go data structures to aid in
|
||||
debugging.
|
||||
|
||||
A quick overview of the additional features spew provides over the built-in
|
||||
printing facilities for Go data types are as follows:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output (only when using
|
||||
Dump style)
|
||||
|
||||
There are two different approaches spew allows for dumping Go data structures:
|
||||
|
||||
* Dump style which prints with newlines, customizable indentation,
|
||||
and additional debug information such as types and all pointer addresses
|
||||
used to indirect to the final value
|
||||
* A custom Formatter interface that integrates cleanly with the standard fmt
|
||||
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
|
||||
similar to the default %v while providing the additional functionality
|
||||
outlined above and passing unsupported format verbs such as %x and %q
|
||||
along to fmt
|
||||
|
||||
Quick Start
|
||||
|
||||
This section demonstrates how to quickly get started with spew. See the
|
||||
sections below for further details on formatting and configuration options.
|
||||
|
||||
To dump a variable with full newlines, indentation, type, and pointer
|
||||
information use Dump, Fdump, or Sdump:
|
||||
spew.Dump(myVar1, myVar2, ...)
|
||||
spew.Fdump(someWriter, myVar1, myVar2, ...)
|
||||
str := spew.Sdump(myVar1, myVar2, ...)
|
||||
|
||||
Alternatively, if you would prefer to use format strings with a compacted inline
|
||||
printing style, use the convenience wrappers Printf, Fprintf, etc with
|
||||
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
|
||||
%#+v (adds types and pointer addresses):
|
||||
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
|
||||
Configuration Options
|
||||
|
||||
Configuration of spew is handled by fields in the ConfigState type. For
|
||||
convenience, all of the top-level functions use a global state available
|
||||
via the spew.Config global.
|
||||
|
||||
It is also possible to create a ConfigState instance that provides methods
|
||||
equivalent to the top-level functions. This allows concurrent configuration
|
||||
options. See the ConfigState documentation for more details.
|
||||
|
||||
The following configuration options are available:
|
||||
* Indent
|
||||
String to use for each indentation level for Dump functions.
|
||||
It is a single space by default. A popular alternative is "\t".
|
||||
|
||||
* MaxDepth
|
||||
Maximum number of levels to descend into nested data structures.
|
||||
There is no limit by default.
|
||||
|
||||
* DisableMethods
|
||||
Disables invocation of error and Stringer interface methods.
|
||||
Method invocation is enabled by default.
|
||||
|
||||
* DisablePointerMethods
|
||||
Disables invocation of error and Stringer interface methods on types
|
||||
which only accept pointer receivers from non-pointer variables.
|
||||
Pointer method invocation is enabled by default.
|
||||
|
||||
* ContinueOnMethod
|
||||
Enables recursion into types after invoking error and Stringer interface
|
||||
methods. Recursion after method invocation is disabled by default.
|
||||
|
||||
* SortKeys
|
||||
Specifies map keys should be sorted before being printed. Use
|
||||
this to have a more deterministic, diffable output. Note that
|
||||
only native types (bool, int, uint, floats, uintptr and string)
|
||||
are supported with other types sorted according to the
|
||||
reflect.Value.String() output which guarantees display stability.
|
||||
Natural map order is used by default.
|
||||
|
||||
Dump Usage
|
||||
|
||||
Simply call spew.Dump with a list of variables you want to dump:
|
||||
|
||||
spew.Dump(myVar1, myVar2, ...)
|
||||
|
||||
You may also call spew.Fdump if you would prefer to output to an arbitrary
|
||||
io.Writer. For example, to dump to standard error:
|
||||
|
||||
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
|
||||
|
||||
A third option is to call spew.Sdump to get the formatted output as a string:
|
||||
|
||||
str := spew.Sdump(myVar1, myVar2, ...)
|
||||
|
||||
Sample Dump Output
|
||||
|
||||
See the Dump example for details on the setup of the types and variables being
|
||||
shown here.
|
||||
|
||||
(main.Foo) {
|
||||
unexportedField: (*main.Bar)(0xf84002e210)({
|
||||
flag: (main.Flag) flagTwo,
|
||||
data: (uintptr) <nil>
|
||||
}),
|
||||
ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
(string) (len=3) "one": (bool) true
|
||||
}
|
||||
}
|
||||
|
||||
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
|
||||
command as shown.
|
||||
([]uint8) (len=32 cap=32) {
|
||||
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||
00000020 31 32 |12|
|
||||
}
|
||||
|
||||
Custom Formatter
|
||||
|
||||
Spew provides a custom formatter that implements the fmt.Formatter interface
|
||||
so that it integrates cleanly with standard fmt package printing functions. The
|
||||
formatter is useful for inline printing of smaller data types similar to the
|
||||
standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Custom Formatter Usage
|
||||
|
||||
The simplest way to make use of the spew custom formatter is to call one of the
|
||||
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
|
||||
functions have syntax you are most likely already familiar with:
|
||||
|
||||
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
spew.Println(myVar, myVar2)
|
||||
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||
|
||||
See the Index for the full list convenience functions.
|
||||
|
||||
Sample Formatter Output
|
||||
|
||||
Double pointer to a uint8:
|
||||
%v: <**>5
|
||||
%+v: <**>(0xf8400420d0->0xf8400420c8)5
|
||||
%#v: (**uint8)5
|
||||
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
|
||||
|
||||
Pointer to circular struct with a uint8 field and a pointer to itself:
|
||||
%v: <*>{1 <*><shown>}
|
||||
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
|
||||
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
|
||||
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
|
||||
|
||||
See the Printf example for details on the setup of variables being shown
|
||||
here.
|
||||
|
||||
Errors
|
||||
|
||||
Since it is possible for custom Stringer/error interfaces to panic, spew
|
||||
detects them and handles them internally by printing the panic information
|
||||
inline with the output. Since spew is intended to provide deep pretty printing
|
||||
capabilities on structures, it intentionally does not return any errors.
|
||||
*/
|
||||
package spew
|
||||
|
|
@ -0,0 +1,500 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
// uint8Type is a reflect.Type representing a uint8. It is used to
|
||||
// convert cgo types to uint8 slices for hexdumping.
|
||||
uint8Type = reflect.TypeOf(uint8(0))
|
||||
|
||||
// cCharRE is a regular expression that matches a cgo char.
|
||||
// It is used to detect character arrays to hexdump them.
|
||||
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
|
||||
|
||||
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
|
||||
// char. It is used to detect unsigned character arrays to hexdump
|
||||
// them.
|
||||
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
|
||||
|
||||
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
|
||||
// It is used to detect uint8_t arrays to hexdump them.
|
||||
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
|
||||
)
|
||||
|
||||
// dumpState contains information about the state of a dump operation.
|
||||
type dumpState struct {
|
||||
w io.Writer
|
||||
depth int
|
||||
pointers map[uintptr]int
|
||||
ignoreNextType bool
|
||||
ignoreNextIndent bool
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// indent performs indentation according to the depth level and cs.Indent
|
||||
// option.
|
||||
func (d *dumpState) indent() {
|
||||
if d.ignoreNextIndent {
|
||||
d.ignoreNextIndent = false
|
||||
return
|
||||
}
|
||||
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
|
||||
}
|
||||
|
||||
// unpackValue returns values inside of non-nil interfaces when possible.
|
||||
// This is useful for data types like structs, arrays, slices, and maps which
|
||||
// can contain varying types packed inside an interface.
|
||||
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Interface && !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// dumpPtr handles formatting of pointers by indirecting them as necessary.
|
||||
func (d *dumpState) dumpPtr(v reflect.Value) {
|
||||
// Remove pointers at or below the current depth from map used to detect
|
||||
// circular refs.
|
||||
for k, depth := range d.pointers {
|
||||
if depth >= d.depth {
|
||||
delete(d.pointers, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep list of all dereferenced pointers to show later.
|
||||
pointerChain := make([]uintptr, 0)
|
||||
|
||||
// Figure out how many levels of indirection there are by dereferencing
|
||||
// pointers and unpacking interfaces down the chain while detecting circular
|
||||
// references.
|
||||
nilFound := false
|
||||
cycleFound := false
|
||||
indirects := 0
|
||||
ve := v
|
||||
for ve.Kind() == reflect.Ptr {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
indirects++
|
||||
addr := ve.Pointer()
|
||||
pointerChain = append(pointerChain, addr)
|
||||
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
|
||||
cycleFound = true
|
||||
indirects--
|
||||
break
|
||||
}
|
||||
d.pointers[addr] = d.depth
|
||||
|
||||
ve = ve.Elem()
|
||||
if ve.Kind() == reflect.Interface {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
ve = ve.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// Display type information.
|
||||
d.w.Write(openParenBytes)
|
||||
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||
d.w.Write([]byte(ve.Type().String()))
|
||||
d.w.Write(closeParenBytes)
|
||||
|
||||
// Display pointer information.
|
||||
if len(pointerChain) > 0 {
|
||||
d.w.Write(openParenBytes)
|
||||
for i, addr := range pointerChain {
|
||||
if i > 0 {
|
||||
d.w.Write(pointerChainBytes)
|
||||
}
|
||||
printHexPtr(d.w, addr)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// Display dereferenced value.
|
||||
d.w.Write(openParenBytes)
|
||||
switch {
|
||||
case nilFound == true:
|
||||
d.w.Write(nilAngleBytes)
|
||||
|
||||
case cycleFound == true:
|
||||
d.w.Write(circularBytes)
|
||||
|
||||
default:
|
||||
d.ignoreNextType = true
|
||||
d.dump(ve)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
|
||||
// reflection) arrays and slices are dumped in hexdump -C fashion.
|
||||
func (d *dumpState) dumpSlice(v reflect.Value) {
|
||||
// Determine whether this type should be hex dumped or not. Also,
|
||||
// for types which should be hexdumped, try to use the underlying data
|
||||
// first, then fall back to trying to convert them to a uint8 slice.
|
||||
var buf []uint8
|
||||
doConvert := false
|
||||
doHexDump := false
|
||||
numEntries := v.Len()
|
||||
if numEntries > 0 {
|
||||
vt := v.Index(0).Type()
|
||||
vts := vt.String()
|
||||
switch {
|
||||
// C types that need to be converted.
|
||||
case cCharRE.MatchString(vts):
|
||||
fallthrough
|
||||
case cUnsignedCharRE.MatchString(vts):
|
||||
fallthrough
|
||||
case cUint8tCharRE.MatchString(vts):
|
||||
doConvert = true
|
||||
|
||||
// Try to use existing uint8 slices and fall back to converting
|
||||
// and copying if that fails.
|
||||
case vt.Kind() == reflect.Uint8:
|
||||
// We need an addressable interface to convert the type back
|
||||
// into a byte slice. However, the reflect package won't give
|
||||
// us an interface on certain things like unexported struct
|
||||
// fields in order to enforce visibility rules. We use unsafe
|
||||
// to bypass these restrictions since this package does not
|
||||
// mutate the values.
|
||||
vs := v
|
||||
if !vs.CanInterface() || !vs.CanAddr() {
|
||||
vs = unsafeReflectValue(vs)
|
||||
}
|
||||
vs = vs.Slice(0, numEntries)
|
||||
|
||||
// Use the existing uint8 slice if it can be type
|
||||
// asserted.
|
||||
iface := vs.Interface()
|
||||
if slice, ok := iface.([]uint8); ok {
|
||||
buf = slice
|
||||
doHexDump = true
|
||||
break
|
||||
}
|
||||
|
||||
// The underlying data needs to be converted if it can't
|
||||
// be type asserted to a uint8 slice.
|
||||
doConvert = true
|
||||
}
|
||||
|
||||
// Copy and convert the underlying type if needed.
|
||||
if doConvert && vt.ConvertibleTo(uint8Type) {
|
||||
// Convert and copy each element into a uint8 byte
|
||||
// slice.
|
||||
buf = make([]uint8, numEntries)
|
||||
for i := 0; i < numEntries; i++ {
|
||||
vv := v.Index(i)
|
||||
buf[i] = uint8(vv.Convert(uint8Type).Uint())
|
||||
}
|
||||
doHexDump = true
|
||||
}
|
||||
}
|
||||
|
||||
// Hexdump the entire slice as needed.
|
||||
if doHexDump {
|
||||
indent := strings.Repeat(d.cs.Indent, d.depth)
|
||||
str := indent + hex.Dump(buf)
|
||||
str = strings.Replace(str, "\n", "\n"+indent, -1)
|
||||
str = strings.TrimRight(str, d.cs.Indent)
|
||||
d.w.Write([]byte(str))
|
||||
return
|
||||
}
|
||||
|
||||
// Recursively call dump for each item.
|
||||
for i := 0; i < numEntries; i++ {
|
||||
d.dump(d.unpackValue(v.Index(i)))
|
||||
if i < (numEntries - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dump is the main workhorse for dumping a value. It uses the passed reflect
|
||||
// value to figure out what kind of object we are dealing with and formats it
|
||||
// appropriately. It is a recursive function, however circular data structures
|
||||
// are detected and handled properly.
|
||||
func (d *dumpState) dump(v reflect.Value) {
|
||||
// Handle invalid reflect values immediately.
|
||||
kind := v.Kind()
|
||||
if kind == reflect.Invalid {
|
||||
d.w.Write(invalidAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle pointers specially.
|
||||
if kind == reflect.Ptr {
|
||||
d.indent()
|
||||
d.dumpPtr(v)
|
||||
return
|
||||
}
|
||||
|
||||
// Print type information unless already handled elsewhere.
|
||||
if !d.ignoreNextType {
|
||||
d.indent()
|
||||
d.w.Write(openParenBytes)
|
||||
d.w.Write([]byte(v.Type().String()))
|
||||
d.w.Write(closeParenBytes)
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
d.ignoreNextType = false
|
||||
|
||||
// Display length and capacity if the built-in len and cap functions
|
||||
// work with the value's kind and the len/cap itself is non-zero.
|
||||
valueLen, valueCap := 0, 0
|
||||
switch v.Kind() {
|
||||
case reflect.Array, reflect.Slice, reflect.Chan:
|
||||
valueLen, valueCap = v.Len(), v.Cap()
|
||||
case reflect.Map, reflect.String:
|
||||
valueLen = v.Len()
|
||||
}
|
||||
if valueLen != 0 || valueCap != 0 {
|
||||
d.w.Write(openParenBytes)
|
||||
if valueLen != 0 {
|
||||
d.w.Write(lenEqualsBytes)
|
||||
printInt(d.w, int64(valueLen), 10)
|
||||
}
|
||||
if valueCap != 0 {
|
||||
if valueLen != 0 {
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
d.w.Write(capEqualsBytes)
|
||||
printInt(d.w, int64(valueCap), 10)
|
||||
}
|
||||
d.w.Write(closeParenBytes)
|
||||
d.w.Write(spaceBytes)
|
||||
}
|
||||
|
||||
// Call Stringer/error interfaces if they exist and the handle methods flag
|
||||
// is enabled
|
||||
if !d.cs.DisableMethods {
|
||||
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||
if handled := handleMethods(d.cs, d.w, v); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Invalid:
|
||||
// Do nothing. We should never get here since invalid has already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Bool:
|
||||
printBool(d.w, v.Bool())
|
||||
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
printInt(d.w, v.Int(), 10)
|
||||
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
printUint(d.w, v.Uint(), 10)
|
||||
|
||||
case reflect.Float32:
|
||||
printFloat(d.w, v.Float(), 32)
|
||||
|
||||
case reflect.Float64:
|
||||
printFloat(d.w, v.Float(), 64)
|
||||
|
||||
case reflect.Complex64:
|
||||
printComplex(d.w, v.Complex(), 32)
|
||||
|
||||
case reflect.Complex128:
|
||||
printComplex(d.w, v.Complex(), 64)
|
||||
|
||||
case reflect.Slice:
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Array:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
d.dumpSlice(v)
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.String:
|
||||
d.w.Write([]byte(strconv.Quote(v.String())))
|
||||
|
||||
case reflect.Interface:
|
||||
// The only time we should get here is for nil interfaces due to
|
||||
// unpackValue calls.
|
||||
if v.IsNil() {
|
||||
d.w.Write(nilAngleBytes)
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// Do nothing. We should never get here since pointers have already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Map:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
numEntries := v.Len()
|
||||
keys := v.MapKeys()
|
||||
if d.cs.SortKeys {
|
||||
sortValues(keys)
|
||||
}
|
||||
for i, key := range keys {
|
||||
d.dump(d.unpackValue(key))
|
||||
d.w.Write(colonSpaceBytes)
|
||||
d.ignoreNextIndent = true
|
||||
d.dump(d.unpackValue(v.MapIndex(key)))
|
||||
if i < (numEntries - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Struct:
|
||||
d.w.Write(openBraceNewlineBytes)
|
||||
d.depth++
|
||||
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||
d.indent()
|
||||
d.w.Write(maxNewlineBytes)
|
||||
} else {
|
||||
vt := v.Type()
|
||||
numFields := v.NumField()
|
||||
for i := 0; i < numFields; i++ {
|
||||
d.indent()
|
||||
vtf := vt.Field(i)
|
||||
d.w.Write([]byte(vtf.Name))
|
||||
d.w.Write(colonSpaceBytes)
|
||||
d.ignoreNextIndent = true
|
||||
d.dump(d.unpackValue(v.Field(i)))
|
||||
if i < (numFields - 1) {
|
||||
d.w.Write(commaNewlineBytes)
|
||||
} else {
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
d.depth--
|
||||
d.indent()
|
||||
d.w.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Uintptr:
|
||||
printHexPtr(d.w, uintptr(v.Uint()))
|
||||
|
||||
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||
printHexPtr(d.w, v.Pointer())
|
||||
|
||||
// There were not any other types at the time this code was written, but
|
||||
// fall back to letting the default fmt package handle it in case any new
|
||||
// types are added.
|
||||
default:
|
||||
if v.CanInterface() {
|
||||
fmt.Fprintf(d.w, "%v", v.Interface())
|
||||
} else {
|
||||
fmt.Fprintf(d.w, "%v", v.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fdump is a helper function to consolidate the logic from the various public
|
||||
// methods which take varying writers and config states.
|
||||
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
|
||||
for _, arg := range a {
|
||||
if arg == nil {
|
||||
w.Write(interfaceBytes)
|
||||
w.Write(spaceBytes)
|
||||
w.Write(nilAngleBytes)
|
||||
w.Write(newlineBytes)
|
||||
continue
|
||||
}
|
||||
|
||||
d := dumpState{w: w, cs: cs}
|
||||
d.pointers = make(map[uintptr]int)
|
||||
d.dump(reflect.ValueOf(arg))
|
||||
d.w.Write(newlineBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||
// exactly the same as Dump.
|
||||
func Fdump(w io.Writer, a ...interface{}) {
|
||||
fdump(&Config, w, a...)
|
||||
}
|
||||
|
||||
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||
// as Dump.
|
||||
func Sdump(a ...interface{}) string {
|
||||
var buf bytes.Buffer
|
||||
fdump(&Config, &buf, a...)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
/*
|
||||
Dump displays the passed parameters to standard out with newlines, customizable
|
||||
indentation, and additional debug information such as complete types and all
|
||||
pointer addresses used to indirect to the final value. It provides the
|
||||
following features over the built-in printing facilities provided by the fmt
|
||||
package:
|
||||
|
||||
* Pointers are dereferenced and followed
|
||||
* Circular data structures are detected and handled properly
|
||||
* Custom Stringer/error interfaces are optionally invoked, including
|
||||
on unexported types
|
||||
* Custom types which only implement the Stringer/error interfaces via
|
||||
a pointer receiver are optionally invoked when passing non-pointer
|
||||
variables
|
||||
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||
includes offsets, byte values in hex, and ASCII output
|
||||
|
||||
The configuration options are controlled by an exported package global,
|
||||
spew.Config. See ConfigState for options documentation.
|
||||
|
||||
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||
get the formatted result as a string.
|
||||
*/
|
||||
func Dump(a ...interface{}) {
|
||||
fdump(&Config, os.Stdout, a...)
|
||||
}
|
||||
|
|
@ -0,0 +1,978 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
/*
|
||||
Test Summary:
|
||||
NOTE: For each test, a nil pointer, a single pointer and double pointer to the
|
||||
base test element are also tested to ensure proper indirection across all types.
|
||||
|
||||
- Max int8, int16, int32, int64, int
|
||||
- Max uint8, uint16, uint32, uint64, uint
|
||||
- Boolean true and false
|
||||
- Standard complex64 and complex128
|
||||
- Array containing standard ints
|
||||
- Array containing type with custom formatter on pointer receiver only
|
||||
- Array containing interfaces
|
||||
- Array containing bytes
|
||||
- Slice containing standard float32 values
|
||||
- Slice containing type with custom formatter on pointer receiver only
|
||||
- Slice containing interfaces
|
||||
- Slice containing bytes
|
||||
- Nil slice
|
||||
- Standard string
|
||||
- Nil interface
|
||||
- Sub-interface
|
||||
- Map with string keys and int vals
|
||||
- Map with custom formatter type on pointer receiver only keys and vals
|
||||
- Map with interface keys and values
|
||||
- Map with nil interface value
|
||||
- Struct with primitives
|
||||
- Struct that contains another struct
|
||||
- Struct that contains custom type with Stringer pointer interface via both
|
||||
exported and unexported fields
|
||||
- Struct that contains embedded struct and field to same struct
|
||||
- Uintptr to 0 (null pointer)
|
||||
- Uintptr address of real variable
|
||||
- Unsafe.Pointer to 0 (null pointer)
|
||||
- Unsafe.Pointer to address of real variable
|
||||
- Nil channel
|
||||
- Standard int channel
|
||||
- Function with no params and no returns
|
||||
- Function with param and no returns
|
||||
- Function with multiple params and multiple returns
|
||||
- Struct that is circular through self referencing
|
||||
- Structs that are circular through cross referencing
|
||||
- Structs that are indirectly circular
|
||||
- Type that panics in its Stringer interface
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// dumpTest is used to describe a test to be perfomed against the Dump method.
|
||||
type dumpTest struct {
|
||||
in interface{}
|
||||
wants []string
|
||||
}
|
||||
|
||||
// dumpTests houses all of the tests to be performed against the Dump method.
|
||||
var dumpTests = make([]dumpTest, 0)
|
||||
|
||||
// addDumpTest is a helper method to append the passed input and desired result
|
||||
// to dumpTests
|
||||
func addDumpTest(in interface{}, wants ...string) {
|
||||
test := dumpTest{in, wants}
|
||||
dumpTests = append(dumpTests, test)
|
||||
}
|
||||
|
||||
func addIntDumpTests() {
|
||||
// Max int8.
|
||||
v := int8(127)
|
||||
nv := (*int8)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "int8"
|
||||
vs := "127"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Max int16.
|
||||
v2 := int16(32767)
|
||||
nv2 := (*int16)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "int16"
|
||||
v2s := "32767"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
|
||||
// Max int32.
|
||||
v3 := int32(2147483647)
|
||||
nv3 := (*int32)(nil)
|
||||
pv3 := &v3
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "int32"
|
||||
v3s := "2147483647"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
|
||||
|
||||
// Max int64.
|
||||
v4 := int64(9223372036854775807)
|
||||
nv4 := (*int64)(nil)
|
||||
pv4 := &v4
|
||||
v4Addr := fmt.Sprintf("%p", pv4)
|
||||
pv4Addr := fmt.Sprintf("%p", &pv4)
|
||||
v4t := "int64"
|
||||
v4s := "9223372036854775807"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
|
||||
|
||||
// Max int.
|
||||
v5 := int(2147483647)
|
||||
nv5 := (*int)(nil)
|
||||
pv5 := &v5
|
||||
v5Addr := fmt.Sprintf("%p", pv5)
|
||||
pv5Addr := fmt.Sprintf("%p", &pv5)
|
||||
v5t := "int"
|
||||
v5s := "2147483647"
|
||||
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
|
||||
addDumpTest(pv5, "(*"+v5t+")("+v5Addr+")("+v5s+")\n")
|
||||
addDumpTest(&pv5, "(**"+v5t+")("+pv5Addr+"->"+v5Addr+")("+v5s+")\n")
|
||||
addDumpTest(nv5, "(*"+v5t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addUintDumpTests() {
|
||||
// Max uint8.
|
||||
v := uint8(255)
|
||||
nv := (*uint8)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "uint8"
|
||||
vs := "255"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Max uint16.
|
||||
v2 := uint16(65535)
|
||||
nv2 := (*uint16)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "uint16"
|
||||
v2s := "65535"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
|
||||
// Max uint32.
|
||||
v3 := uint32(4294967295)
|
||||
nv3 := (*uint32)(nil)
|
||||
pv3 := &v3
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "uint32"
|
||||
v3s := "4294967295"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
|
||||
|
||||
// Max uint64.
|
||||
v4 := uint64(18446744073709551615)
|
||||
nv4 := (*uint64)(nil)
|
||||
pv4 := &v4
|
||||
v4Addr := fmt.Sprintf("%p", pv4)
|
||||
pv4Addr := fmt.Sprintf("%p", &pv4)
|
||||
v4t := "uint64"
|
||||
v4s := "18446744073709551615"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
|
||||
|
||||
// Max uint.
|
||||
v5 := uint(4294967295)
|
||||
nv5 := (*uint)(nil)
|
||||
pv5 := &v5
|
||||
v5Addr := fmt.Sprintf("%p", pv5)
|
||||
pv5Addr := fmt.Sprintf("%p", &pv5)
|
||||
v5t := "uint"
|
||||
v5s := "4294967295"
|
||||
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
|
||||
addDumpTest(pv5, "(*"+v5t+")("+v5Addr+")("+v5s+")\n")
|
||||
addDumpTest(&pv5, "(**"+v5t+")("+pv5Addr+"->"+v5Addr+")("+v5s+")\n")
|
||||
addDumpTest(nv5, "(*"+v5t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addBoolDumpTests() {
|
||||
// Boolean true.
|
||||
v := bool(true)
|
||||
nv := (*bool)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "bool"
|
||||
vs := "true"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Boolean false.
|
||||
v2 := bool(false)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "bool"
|
||||
v2s := "false"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
}
|
||||
|
||||
func addFloatDumpTests() {
|
||||
// Standard float32.
|
||||
v := float32(3.1415)
|
||||
nv := (*float32)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "float32"
|
||||
vs := "3.1415"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Standard float64.
|
||||
v2 := float64(3.1415926)
|
||||
nv2 := (*float64)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "float64"
|
||||
v2s := "3.1415926"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addComplexDumpTests() {
|
||||
// Standard complex64.
|
||||
v := complex(float32(6), -2)
|
||||
nv := (*complex64)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "complex64"
|
||||
vs := "(6-2i)"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Standard complex128.
|
||||
v2 := complex(float64(-6), 2)
|
||||
nv2 := (*complex128)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "complex128"
|
||||
v2s := "(-6+2i)"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addArrayDumpTests() {
|
||||
// Array containing standard ints.
|
||||
v := [3]int{1, 2, 3}
|
||||
vLen := fmt.Sprintf("%d", len(v))
|
||||
vCap := fmt.Sprintf("%d", cap(v))
|
||||
nv := (*[3]int)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "int"
|
||||
vs := "(len=" + vLen + " cap=" + vCap + ") {\n (" + vt + ") 1,\n (" +
|
||||
vt + ") 2,\n (" + vt + ") 3\n}"
|
||||
addDumpTest(v, "([3]"+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*[3]"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**[3]"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*[3]"+vt+")(<nil>)\n")
|
||||
|
||||
// Array containing type with custom formatter on pointer receiver only.
|
||||
v2i0 := pstringer("1")
|
||||
v2i1 := pstringer("2")
|
||||
v2i2 := pstringer("3")
|
||||
v2 := [3]pstringer{v2i0, v2i1, v2i2}
|
||||
v2i0Len := fmt.Sprintf("%d", len(v2i0))
|
||||
v2i1Len := fmt.Sprintf("%d", len(v2i1))
|
||||
v2i2Len := fmt.Sprintf("%d", len(v2i2))
|
||||
v2Len := fmt.Sprintf("%d", len(v2))
|
||||
v2Cap := fmt.Sprintf("%d", cap(v2))
|
||||
nv2 := (*[3]pstringer)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "spew_test.pstringer"
|
||||
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") {\n (" + v2t + ") (len=" +
|
||||
v2i0Len + ") stringer 1,\n (" + v2t + ") (len=" + v2i1Len +
|
||||
") stringer 2,\n (" + v2t + ") (len=" + v2i2Len + ") " +
|
||||
"stringer 3\n}"
|
||||
addDumpTest(v2, "([3]"+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*[3]"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**[3]"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*[3]"+v2t+")(<nil>)\n")
|
||||
|
||||
// Array containing interfaces.
|
||||
v3i0 := "one"
|
||||
v3 := [3]interface{}{v3i0, int(2), uint(3)}
|
||||
v3i0Len := fmt.Sprintf("%d", len(v3i0))
|
||||
v3Len := fmt.Sprintf("%d", len(v3))
|
||||
v3Cap := fmt.Sprintf("%d", cap(v3))
|
||||
nv3 := (*[3]interface{})(nil)
|
||||
pv3 := &v3
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "[3]interface {}"
|
||||
v3t2 := "string"
|
||||
v3t3 := "int"
|
||||
v3t4 := "uint"
|
||||
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") {\n (" + v3t2 + ") " +
|
||||
"(len=" + v3i0Len + ") \"one\",\n (" + v3t3 + ") 2,\n (" +
|
||||
v3t4 + ") 3\n}"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
|
||||
|
||||
// Array containing bytes.
|
||||
v4 := [34]byte{
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
|
||||
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
|
||||
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
|
||||
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
|
||||
0x31, 0x32,
|
||||
}
|
||||
v4Len := fmt.Sprintf("%d", len(v4))
|
||||
v4Cap := fmt.Sprintf("%d", cap(v4))
|
||||
nv4 := (*[34]byte)(nil)
|
||||
pv4 := &v4
|
||||
v4Addr := fmt.Sprintf("%p", pv4)
|
||||
pv4Addr := fmt.Sprintf("%p", &pv4)
|
||||
v4t := "[34]uint8"
|
||||
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
|
||||
"{\n 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20" +
|
||||
" |............... |\n" +
|
||||
" 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30" +
|
||||
" |!\"#$%&'()*+,-./0|\n" +
|
||||
" 00000020 31 32 " +
|
||||
" |12|\n}"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addSliceDumpTests() {
|
||||
// Slice containing standard float32 values.
|
||||
v := []float32{3.14, 6.28, 12.56}
|
||||
vLen := fmt.Sprintf("%d", len(v))
|
||||
vCap := fmt.Sprintf("%d", cap(v))
|
||||
nv := (*[]float32)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "float32"
|
||||
vs := "(len=" + vLen + " cap=" + vCap + ") {\n (" + vt + ") 3.14,\n (" +
|
||||
vt + ") 6.28,\n (" + vt + ") 12.56\n}"
|
||||
addDumpTest(v, "([]"+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*[]"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**[]"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*[]"+vt+")(<nil>)\n")
|
||||
|
||||
// Slice containing type with custom formatter on pointer receiver only.
|
||||
v2i0 := pstringer("1")
|
||||
v2i1 := pstringer("2")
|
||||
v2i2 := pstringer("3")
|
||||
v2 := []pstringer{v2i0, v2i1, v2i2}
|
||||
v2i0Len := fmt.Sprintf("%d", len(v2i0))
|
||||
v2i1Len := fmt.Sprintf("%d", len(v2i1))
|
||||
v2i2Len := fmt.Sprintf("%d", len(v2i2))
|
||||
v2Len := fmt.Sprintf("%d", len(v2))
|
||||
v2Cap := fmt.Sprintf("%d", cap(v2))
|
||||
nv2 := (*[]pstringer)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "spew_test.pstringer"
|
||||
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") {\n (" + v2t + ") (len=" +
|
||||
v2i0Len + ") stringer 1,\n (" + v2t + ") (len=" + v2i1Len +
|
||||
") stringer 2,\n (" + v2t + ") (len=" + v2i2Len + ") " +
|
||||
"stringer 3\n}"
|
||||
addDumpTest(v2, "([]"+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*[]"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**[]"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*[]"+v2t+")(<nil>)\n")
|
||||
|
||||
// Slice containing interfaces.
|
||||
v3i0 := "one"
|
||||
v3 := []interface{}{v3i0, int(2), uint(3), nil}
|
||||
v3i0Len := fmt.Sprintf("%d", len(v3i0))
|
||||
v3Len := fmt.Sprintf("%d", len(v3))
|
||||
v3Cap := fmt.Sprintf("%d", cap(v3))
|
||||
nv3 := (*[]interface{})(nil)
|
||||
pv3 := &v3
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "[]interface {}"
|
||||
v3t2 := "string"
|
||||
v3t3 := "int"
|
||||
v3t4 := "uint"
|
||||
v3t5 := "interface {}"
|
||||
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") {\n (" + v3t2 + ") " +
|
||||
"(len=" + v3i0Len + ") \"one\",\n (" + v3t3 + ") 2,\n (" +
|
||||
v3t4 + ") 3,\n (" + v3t5 + ") <nil>\n}"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
|
||||
|
||||
// Slice containing bytes.
|
||||
v4 := []byte{
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
|
||||
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
|
||||
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
|
||||
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
|
||||
0x31, 0x32,
|
||||
}
|
||||
v4Len := fmt.Sprintf("%d", len(v4))
|
||||
v4Cap := fmt.Sprintf("%d", cap(v4))
|
||||
nv4 := (*[]byte)(nil)
|
||||
pv4 := &v4
|
||||
v4Addr := fmt.Sprintf("%p", pv4)
|
||||
pv4Addr := fmt.Sprintf("%p", &pv4)
|
||||
v4t := "[]uint8"
|
||||
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
|
||||
"{\n 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20" +
|
||||
" |............... |\n" +
|
||||
" 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30" +
|
||||
" |!\"#$%&'()*+,-./0|\n" +
|
||||
" 00000020 31 32 " +
|
||||
" |12|\n}"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
|
||||
|
||||
// Nil slice.
|
||||
v5 := []int(nil)
|
||||
nv5 := (*[]int)(nil)
|
||||
pv5 := &v5
|
||||
v5Addr := fmt.Sprintf("%p", pv5)
|
||||
pv5Addr := fmt.Sprintf("%p", &pv5)
|
||||
v5t := "[]int"
|
||||
v5s := "<nil>"
|
||||
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
|
||||
addDumpTest(pv5, "(*"+v5t+")("+v5Addr+")("+v5s+")\n")
|
||||
addDumpTest(&pv5, "(**"+v5t+")("+pv5Addr+"->"+v5Addr+")("+v5s+")\n")
|
||||
addDumpTest(nv5, "(*"+v5t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addStringDumpTests() {
|
||||
// Standard string.
|
||||
v := "test"
|
||||
vLen := fmt.Sprintf("%d", len(v))
|
||||
nv := (*string)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "string"
|
||||
vs := "(len=" + vLen + ") \"test\""
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addInterfaceDumpTests() {
|
||||
// Nil interface.
|
||||
var v interface{}
|
||||
nv := (*interface{})(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "interface {}"
|
||||
vs := "<nil>"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Sub-interface.
|
||||
v2 := interface{}(uint16(65535))
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "uint16"
|
||||
v2s := "65535"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
}
|
||||
|
||||
func addMapDumpTests() {
|
||||
// Map with string keys and int vals.
|
||||
k := "one"
|
||||
kk := "two"
|
||||
m := map[string]int{k: 1, kk: 2}
|
||||
klen := fmt.Sprintf("%d", len(k)) // not kLen to shut golint up
|
||||
kkLen := fmt.Sprintf("%d", len(kk))
|
||||
mLen := fmt.Sprintf("%d", len(m))
|
||||
nm := (*map[string]int)(nil)
|
||||
pm := &m
|
||||
mAddr := fmt.Sprintf("%p", pm)
|
||||
pmAddr := fmt.Sprintf("%p", &pm)
|
||||
mt := "map[string]int"
|
||||
mt1 := "string"
|
||||
mt2 := "int"
|
||||
ms := "(len=" + mLen + ") {\n (" + mt1 + ") (len=" + klen + ") " +
|
||||
"\"one\": (" + mt2 + ") 1,\n (" + mt1 + ") (len=" + kkLen +
|
||||
") \"two\": (" + mt2 + ") 2\n}"
|
||||
ms2 := "(len=" + mLen + ") {\n (" + mt1 + ") (len=" + kkLen + ") " +
|
||||
"\"two\": (" + mt2 + ") 2,\n (" + mt1 + ") (len=" + klen +
|
||||
") \"one\": (" + mt2 + ") 1\n}"
|
||||
addDumpTest(m, "("+mt+") "+ms+"\n", "("+mt+") "+ms2+"\n")
|
||||
addDumpTest(pm, "(*"+mt+")("+mAddr+")("+ms+")\n",
|
||||
"(*"+mt+")("+mAddr+")("+ms2+")\n")
|
||||
addDumpTest(&pm, "(**"+mt+")("+pmAddr+"->"+mAddr+")("+ms+")\n",
|
||||
"(**"+mt+")("+pmAddr+"->"+mAddr+")("+ms2+")\n")
|
||||
addDumpTest(nm, "(*"+mt+")(<nil>)\n")
|
||||
|
||||
// Map with custom formatter type on pointer receiver only keys and vals.
|
||||
k2 := pstringer("one")
|
||||
v2 := pstringer("1")
|
||||
m2 := map[pstringer]pstringer{k2: v2}
|
||||
k2Len := fmt.Sprintf("%d", len(k2))
|
||||
v2Len := fmt.Sprintf("%d", len(v2))
|
||||
m2Len := fmt.Sprintf("%d", len(m2))
|
||||
nm2 := (*map[pstringer]pstringer)(nil)
|
||||
pm2 := &m2
|
||||
m2Addr := fmt.Sprintf("%p", pm2)
|
||||
pm2Addr := fmt.Sprintf("%p", &pm2)
|
||||
m2t := "map[spew_test.pstringer]spew_test.pstringer"
|
||||
m2t1 := "spew_test.pstringer"
|
||||
m2t2 := "spew_test.pstringer"
|
||||
m2s := "(len=" + m2Len + ") {\n (" + m2t1 + ") (len=" + k2Len + ") " +
|
||||
"stringer one: (" + m2t2 + ") (len=" + v2Len + ") stringer 1\n}"
|
||||
addDumpTest(m2, "("+m2t+") "+m2s+"\n")
|
||||
addDumpTest(pm2, "(*"+m2t+")("+m2Addr+")("+m2s+")\n")
|
||||
addDumpTest(&pm2, "(**"+m2t+")("+pm2Addr+"->"+m2Addr+")("+m2s+")\n")
|
||||
addDumpTest(nm2, "(*"+m2t+")(<nil>)\n")
|
||||
|
||||
// Map with interface keys and values.
|
||||
k3 := "one"
|
||||
k3Len := fmt.Sprintf("%d", len(k3))
|
||||
m3 := map[interface{}]interface{}{k3: 1}
|
||||
m3Len := fmt.Sprintf("%d", len(m3))
|
||||
nm3 := (*map[interface{}]interface{})(nil)
|
||||
pm3 := &m3
|
||||
m3Addr := fmt.Sprintf("%p", pm3)
|
||||
pm3Addr := fmt.Sprintf("%p", &pm3)
|
||||
m3t := "map[interface {}]interface {}"
|
||||
m3t1 := "string"
|
||||
m3t2 := "int"
|
||||
m3s := "(len=" + m3Len + ") {\n (" + m3t1 + ") (len=" + k3Len + ") " +
|
||||
"\"one\": (" + m3t2 + ") 1\n}"
|
||||
addDumpTest(m3, "("+m3t+") "+m3s+"\n")
|
||||
addDumpTest(pm3, "(*"+m3t+")("+m3Addr+")("+m3s+")\n")
|
||||
addDumpTest(&pm3, "(**"+m3t+")("+pm3Addr+"->"+m3Addr+")("+m3s+")\n")
|
||||
addDumpTest(nm3, "(*"+m3t+")(<nil>)\n")
|
||||
|
||||
// Map with nil interface value.
|
||||
k4 := "nil"
|
||||
k4Len := fmt.Sprintf("%d", len(k4))
|
||||
m4 := map[string]interface{}{k4: nil}
|
||||
m4Len := fmt.Sprintf("%d", len(m4))
|
||||
nm4 := (*map[string]interface{})(nil)
|
||||
pm4 := &m4
|
||||
m4Addr := fmt.Sprintf("%p", pm4)
|
||||
pm4Addr := fmt.Sprintf("%p", &pm4)
|
||||
m4t := "map[string]interface {}"
|
||||
m4t1 := "string"
|
||||
m4t2 := "interface {}"
|
||||
m4s := "(len=" + m4Len + ") {\n (" + m4t1 + ") (len=" + k4Len + ")" +
|
||||
" \"nil\": (" + m4t2 + ") <nil>\n}"
|
||||
addDumpTest(m4, "("+m4t+") "+m4s+"\n")
|
||||
addDumpTest(pm4, "(*"+m4t+")("+m4Addr+")("+m4s+")\n")
|
||||
addDumpTest(&pm4, "(**"+m4t+")("+pm4Addr+"->"+m4Addr+")("+m4s+")\n")
|
||||
addDumpTest(nm4, "(*"+m4t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addStructDumpTests() {
|
||||
// Struct with primitives.
|
||||
type s1 struct {
|
||||
a int8
|
||||
b uint8
|
||||
}
|
||||
v := s1{127, 255}
|
||||
nv := (*s1)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "spew_test.s1"
|
||||
vt2 := "int8"
|
||||
vt3 := "uint8"
|
||||
vs := "{\n a: (" + vt2 + ") 127,\n b: (" + vt3 + ") 255\n}"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Struct that contains another struct.
|
||||
type s2 struct {
|
||||
s1 s1
|
||||
b bool
|
||||
}
|
||||
v2 := s2{s1{127, 255}, true}
|
||||
nv2 := (*s2)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "spew_test.s2"
|
||||
v2t2 := "spew_test.s1"
|
||||
v2t3 := "int8"
|
||||
v2t4 := "uint8"
|
||||
v2t5 := "bool"
|
||||
v2s := "{\n s1: (" + v2t2 + ") {\n a: (" + v2t3 + ") 127,\n b: (" +
|
||||
v2t4 + ") 255\n },\n b: (" + v2t5 + ") true\n}"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
|
||||
// Struct that contains custom type with Stringer pointer interface via both
|
||||
// exported and unexported fields.
|
||||
type s3 struct {
|
||||
s pstringer
|
||||
S pstringer
|
||||
}
|
||||
v3 := s3{"test", "test2"}
|
||||
nv3 := (*s3)(nil)
|
||||
pv3 := &v3
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "spew_test.s3"
|
||||
v3t2 := "spew_test.pstringer"
|
||||
v3s := "{\n s: (" + v3t2 + ") (len=4) stringer test,\n S: (" + v3t2 +
|
||||
") (len=5) stringer test2\n}"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
|
||||
|
||||
// Struct that contains embedded struct and field to same struct.
|
||||
e := embed{"embedstr"}
|
||||
eLen := fmt.Sprintf("%d", len("embedstr"))
|
||||
v4 := embedwrap{embed: &e, e: &e}
|
||||
nv4 := (*embedwrap)(nil)
|
||||
pv4 := &v4
|
||||
eAddr := fmt.Sprintf("%p", &e)
|
||||
v4Addr := fmt.Sprintf("%p", pv4)
|
||||
pv4Addr := fmt.Sprintf("%p", &pv4)
|
||||
v4t := "spew_test.embedwrap"
|
||||
v4t2 := "spew_test.embed"
|
||||
v4t3 := "string"
|
||||
v4s := "{\n embed: (*" + v4t2 + ")(" + eAddr + ")({\n a: (" + v4t3 +
|
||||
") (len=" + eLen + ") \"embedstr\"\n }),\n e: (*" + v4t2 +
|
||||
")(" + eAddr + ")({\n a: (" + v4t3 + ") (len=" + eLen + ")" +
|
||||
" \"embedstr\"\n })\n}"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
|
||||
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addUintptrDumpTests() {
|
||||
// Null pointer.
|
||||
v := uintptr(0)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "uintptr"
|
||||
vs := "<nil>"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
|
||||
// Address of real variable.
|
||||
i := 1
|
||||
v2 := uintptr(unsafe.Pointer(&i))
|
||||
nv2 := (*uintptr)(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "uintptr"
|
||||
v2s := fmt.Sprintf("%p", &i)
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addUnsafePointerDumpTests() {
|
||||
// Null pointer.
|
||||
v := unsafe.Pointer(uintptr(0))
|
||||
nv := (*unsafe.Pointer)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "unsafe.Pointer"
|
||||
vs := "<nil>"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Address of real variable.
|
||||
i := 1
|
||||
v2 := unsafe.Pointer(&i)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "unsafe.Pointer"
|
||||
v2s := fmt.Sprintf("%p", &i)
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addChanDumpTests() {
|
||||
// Nil channel.
|
||||
var v chan int
|
||||
pv := &v
|
||||
nv := (*chan int)(nil)
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "chan int"
|
||||
vs := "<nil>"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Real channel.
|
||||
v2 := make(chan int)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "chan int"
|
||||
v2s := fmt.Sprintf("%p", v2)
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
}
|
||||
|
||||
func addFuncDumpTests() {
|
||||
// Function with no params and no returns.
|
||||
v := addIntDumpTests
|
||||
nv := (*func())(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "func()"
|
||||
vs := fmt.Sprintf("%p", v)
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
|
||||
// Function with param and no returns.
|
||||
v2 := TestDump
|
||||
nv2 := (*func(*testing.T))(nil)
|
||||
pv2 := &v2
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "func(*testing.T)"
|
||||
v2s := fmt.Sprintf("%p", v2)
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
|
||||
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
|
||||
|
||||
// Function with multiple params and multiple returns.
|
||||
var v3 = func(i int, s string) (b bool, err error) {
|
||||
return true, nil
|
||||
}
|
||||
nv3 := (*func(int, string) (bool, error))(nil)
|
||||
pv3 := &v3
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "func(int, string) (bool, error)"
|
||||
v3s := fmt.Sprintf("%p", v3)
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
|
||||
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addCircularDumpTests() {
|
||||
// Struct that is circular through self referencing.
|
||||
type circular struct {
|
||||
c *circular
|
||||
}
|
||||
v := circular{nil}
|
||||
v.c = &v
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "spew_test.circular"
|
||||
vs := "{\n c: (*" + vt + ")(" + vAddr + ")({\n c: (*" + vt + ")(" +
|
||||
vAddr + ")(<already shown>)\n })\n}"
|
||||
vs2 := "{\n c: (*" + vt + ")(" + vAddr + ")(<already shown>)\n}"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs2+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs2+")\n")
|
||||
|
||||
// Structs that are circular through cross referencing.
|
||||
v2 := xref1{nil}
|
||||
ts2 := xref2{&v2}
|
||||
v2.ps2 = &ts2
|
||||
pv2 := &v2
|
||||
ts2Addr := fmt.Sprintf("%p", &ts2)
|
||||
v2Addr := fmt.Sprintf("%p", pv2)
|
||||
pv2Addr := fmt.Sprintf("%p", &pv2)
|
||||
v2t := "spew_test.xref1"
|
||||
v2t2 := "spew_test.xref2"
|
||||
v2s := "{\n ps2: (*" + v2t2 + ")(" + ts2Addr + ")({\n ps1: (*" + v2t +
|
||||
")(" + v2Addr + ")({\n ps2: (*" + v2t2 + ")(" + ts2Addr +
|
||||
")(<already shown>)\n })\n })\n}"
|
||||
v2s2 := "{\n ps2: (*" + v2t2 + ")(" + ts2Addr + ")({\n ps1: (*" + v2t +
|
||||
")(" + v2Addr + ")(<already shown>)\n })\n}"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s2+")\n")
|
||||
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s2+")\n")
|
||||
|
||||
// Structs that are indirectly circular.
|
||||
v3 := indirCir1{nil}
|
||||
tic2 := indirCir2{nil}
|
||||
tic3 := indirCir3{&v3}
|
||||
tic2.ps3 = &tic3
|
||||
v3.ps2 = &tic2
|
||||
pv3 := &v3
|
||||
tic2Addr := fmt.Sprintf("%p", &tic2)
|
||||
tic3Addr := fmt.Sprintf("%p", &tic3)
|
||||
v3Addr := fmt.Sprintf("%p", pv3)
|
||||
pv3Addr := fmt.Sprintf("%p", &pv3)
|
||||
v3t := "spew_test.indirCir1"
|
||||
v3t2 := "spew_test.indirCir2"
|
||||
v3t3 := "spew_test.indirCir3"
|
||||
v3s := "{\n ps2: (*" + v3t2 + ")(" + tic2Addr + ")({\n ps3: (*" + v3t3 +
|
||||
")(" + tic3Addr + ")({\n ps1: (*" + v3t + ")(" + v3Addr +
|
||||
")({\n ps2: (*" + v3t2 + ")(" + tic2Addr +
|
||||
")(<already shown>)\n })\n })\n })\n}"
|
||||
v3s2 := "{\n ps2: (*" + v3t2 + ")(" + tic2Addr + ")({\n ps3: (*" + v3t3 +
|
||||
")(" + tic3Addr + ")({\n ps1: (*" + v3t + ")(" + v3Addr +
|
||||
")(<already shown>)\n })\n })\n}"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s2+")\n")
|
||||
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s2+")\n")
|
||||
}
|
||||
|
||||
func addPanicDumpTests() {
|
||||
// Type that panics in its Stringer interface.
|
||||
v := panicer(127)
|
||||
nv := (*panicer)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "spew_test.panicer"
|
||||
vs := "(PANIC=test panic)127"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
}
|
||||
|
||||
func addErrorDumpTests() {
|
||||
// Type that has a custom Error interface.
|
||||
v := customError(127)
|
||||
nv := (*customError)(nil)
|
||||
pv := &v
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "spew_test.customError"
|
||||
vs := "error: 127"
|
||||
addDumpTest(v, "("+vt+") "+vs+"\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
|
||||
}
|
||||
|
||||
// TestDump executes all of the tests described by dumpTests.
|
||||
func TestDump(t *testing.T) {
|
||||
// Setup tests.
|
||||
addIntDumpTests()
|
||||
addUintDumpTests()
|
||||
addBoolDumpTests()
|
||||
addFloatDumpTests()
|
||||
addComplexDumpTests()
|
||||
addArrayDumpTests()
|
||||
addSliceDumpTests()
|
||||
addStringDumpTests()
|
||||
addInterfaceDumpTests()
|
||||
addMapDumpTests()
|
||||
addStructDumpTests()
|
||||
addUintptrDumpTests()
|
||||
addUnsafePointerDumpTests()
|
||||
addChanDumpTests()
|
||||
addFuncDumpTests()
|
||||
addCircularDumpTests()
|
||||
addPanicDumpTests()
|
||||
addErrorDumpTests()
|
||||
addCgoDumpTests()
|
||||
|
||||
t.Logf("Running %d tests", len(dumpTests))
|
||||
for i, test := range dumpTests {
|
||||
buf := new(bytes.Buffer)
|
||||
spew.Fdump(buf, test.in)
|
||||
s := buf.String()
|
||||
if testFailed(s, test.wants) {
|
||||
t.Errorf("Dump #%d\n got: %s %s", i, s, stringizeWants(test.wants))
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDumpSortedKeys(t *testing.T) {
|
||||
cfg := spew.ConfigState{SortKeys: true}
|
||||
s := cfg.Sdump(map[int]string{1: "1", 3: "3", 2: "2"})
|
||||
expected := `(map[int]string) (len=3) {
|
||||
(int) 1: (string) (len=1) "1",
|
||||
(int) 2: (string) (len=1) "2",
|
||||
(int) 3: (string) (len=1) "3"
|
||||
}
|
||||
`
|
||||
if s != expected {
|
||||
t.Errorf("Sorted keys mismatch:\n %v %v", s, expected)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when both cgo is supported and "-tags testcgo" is added to the go test
|
||||
// command line. This means the cgo tests are only added (and hence run) when
|
||||
// specifially requested. This configuration is used because spew itself
|
||||
// does not require cgo to run even though it does handle certain cgo types
|
||||
// specially. Rather than forcing all clients to require cgo and an external
|
||||
// C compiler just to run the tests, this scheme makes them optional.
|
||||
// +build cgo,testcgo
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew/testdata"
|
||||
)
|
||||
|
||||
func addCgoDumpTests() {
|
||||
// C char pointer.
|
||||
v := testdata.GetCgoCharPointer()
|
||||
nv := testdata.GetCgoNullCharPointer()
|
||||
pv := &v
|
||||
vcAddr := fmt.Sprintf("%p", v)
|
||||
vAddr := fmt.Sprintf("%p", pv)
|
||||
pvAddr := fmt.Sprintf("%p", &pv)
|
||||
vt := "*testdata._Ctype_char"
|
||||
vs := "116"
|
||||
addDumpTest(v, "("+vt+")("+vcAddr+")("+vs+")\n")
|
||||
addDumpTest(pv, "(*"+vt+")("+vAddr+"->"+vcAddr+")("+vs+")\n")
|
||||
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+"->"+vcAddr+")("+vs+")\n")
|
||||
addDumpTest(nv, "("+vt+")(<nil>)\n")
|
||||
|
||||
// C char array.
|
||||
v2, v2l, v2c := testdata.GetCgoCharArray()
|
||||
v2Len := fmt.Sprintf("%d", v2l)
|
||||
v2Cap := fmt.Sprintf("%d", v2c)
|
||||
v2t := "[6]testdata._Ctype_char"
|
||||
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 32 00 " +
|
||||
" |test2.|\n}"
|
||||
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
|
||||
|
||||
// C unsigned char array.
|
||||
v3, v3l, v3c := testdata.GetCgoUnsignedCharArray()
|
||||
v3Len := fmt.Sprintf("%d", v3l)
|
||||
v3Cap := fmt.Sprintf("%d", v3c)
|
||||
v3t := "[6]testdata._Ctype_unsignedchar"
|
||||
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 33 00 " +
|
||||
" |test3.|\n}"
|
||||
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
|
||||
|
||||
// C signed char array.
|
||||
v4, v4l, v4c := testdata.GetCgoSignedCharArray()
|
||||
v4Len := fmt.Sprintf("%d", v4l)
|
||||
v4Cap := fmt.Sprintf("%d", v4c)
|
||||
v4t := "[6]testdata._Ctype_schar"
|
||||
v4t2 := "testdata._Ctype_schar"
|
||||
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
|
||||
"{\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 101,\n (" + v4t2 +
|
||||
") 115,\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 52,\n (" + v4t2 +
|
||||
") 0\n}"
|
||||
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
|
||||
|
||||
// C uint8_t array.
|
||||
v5, v5l, v5c := testdata.GetCgoUint8tArray()
|
||||
v5Len := fmt.Sprintf("%d", v5l)
|
||||
v5Cap := fmt.Sprintf("%d", v5c)
|
||||
v5t := "[6]testdata._Ctype_uint8_t"
|
||||
v5s := "(len=" + v5Len + " cap=" + v5Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 35 00 " +
|
||||
" |test5.|\n}"
|
||||
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
|
||||
|
||||
// C typedefed unsigned char array.
|
||||
v6, v6l, v6c := testdata.GetCgoTypdefedUnsignedCharArray()
|
||||
v6Len := fmt.Sprintf("%d", v6l)
|
||||
v6Cap := fmt.Sprintf("%d", v6c)
|
||||
v6t := "[6]testdata._Ctype_custom_uchar_t"
|
||||
v6s := "(len=" + v6Len + " cap=" + v6Cap + ") " +
|
||||
"{\n 00000000 74 65 73 74 36 00 " +
|
||||
" |test6.|\n}"
|
||||
addDumpTest(v6, "("+v6t+") "+v6s+"\n")
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when either cgo is not supported or "-tags testcgo" is not added to the go
|
||||
// test command line. This file intentionally does not setup any cgo tests in
|
||||
// this scenario.
|
||||
// +build !cgo !testcgo
|
||||
|
||||
package spew_test
|
||||
|
||||
func addCgoDumpTests() {
|
||||
// Don't add any tests for cgo since this file is only compiled when
|
||||
// there should not be any cgo tests.
|
||||
}
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
type Flag int
|
||||
|
||||
const (
|
||||
flagOne Flag = iota
|
||||
flagTwo
|
||||
)
|
||||
|
||||
var flagStrings = map[Flag]string{
|
||||
flagOne: "flagOne",
|
||||
flagTwo: "flagTwo",
|
||||
}
|
||||
|
||||
func (f Flag) String() string {
|
||||
if s, ok := flagStrings[f]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("Unknown flag (%d)", int(f))
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
flag Flag
|
||||
data uintptr
|
||||
}
|
||||
|
||||
type Foo struct {
|
||||
unexportedField Bar
|
||||
ExportedField map[interface{}]interface{}
|
||||
}
|
||||
|
||||
// This example demonstrates how to use Dump to dump variables to stdout.
|
||||
func ExampleDump() {
|
||||
// The following package level declarations are assumed for this example:
|
||||
/*
|
||||
type Flag int
|
||||
|
||||
const (
|
||||
flagOne Flag = iota
|
||||
flagTwo
|
||||
)
|
||||
|
||||
var flagStrings = map[Flag]string{
|
||||
flagOne: "flagOne",
|
||||
flagTwo: "flagTwo",
|
||||
}
|
||||
|
||||
func (f Flag) String() string {
|
||||
if s, ok := flagStrings[f]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("Unknown flag (%d)", int(f))
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
flag Flag
|
||||
data uintptr
|
||||
}
|
||||
|
||||
type Foo struct {
|
||||
unexportedField Bar
|
||||
ExportedField map[interface{}]interface{}
|
||||
}
|
||||
*/
|
||||
|
||||
// Setup some sample data structures for the example.
|
||||
bar := Bar{Flag(flagTwo), uintptr(0)}
|
||||
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
|
||||
f := Flag(5)
|
||||
b := []byte{
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
|
||||
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
|
||||
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
|
||||
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
|
||||
0x31, 0x32,
|
||||
}
|
||||
|
||||
// Dump!
|
||||
spew.Dump(s1, f, b)
|
||||
|
||||
// Output:
|
||||
// (spew_test.Foo) {
|
||||
// unexportedField: (spew_test.Bar) {
|
||||
// flag: (spew_test.Flag) flagTwo,
|
||||
// data: (uintptr) <nil>
|
||||
// },
|
||||
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
// (string) (len=3) "one": (bool) true
|
||||
// }
|
||||
// }
|
||||
// (spew_test.Flag) Unknown flag (5)
|
||||
// ([]uint8) (len=34 cap=34) {
|
||||
// 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||
// 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||
// 00000020 31 32 |12|
|
||||
// }
|
||||
//
|
||||
}
|
||||
|
||||
// This example demonstrates how to use Printf to display a variable with a
|
||||
// format string and inline formatting.
|
||||
func ExamplePrintf() {
|
||||
// Create a double pointer to a uint 8.
|
||||
ui8 := uint8(5)
|
||||
pui8 := &ui8
|
||||
ppui8 := &pui8
|
||||
|
||||
// Create a circular data type.
|
||||
type circular struct {
|
||||
ui8 uint8
|
||||
c *circular
|
||||
}
|
||||
c := circular{ui8: 1}
|
||||
c.c = &c
|
||||
|
||||
// Print!
|
||||
spew.Printf("ppui8: %v\n", ppui8)
|
||||
spew.Printf("circular: %v\n", c)
|
||||
|
||||
// Output:
|
||||
// ppui8: <**>5
|
||||
// circular: {1 <*>{1 <*><shown>}}
|
||||
}
|
||||
|
||||
// This example demonstrates how to use a ConfigState.
|
||||
func ExampleConfigState() {
|
||||
// Modify the indent level of the ConfigState only. The global
|
||||
// configuration is not modified.
|
||||
scs := spew.ConfigState{Indent: "\t"}
|
||||
|
||||
// Output using the ConfigState instance.
|
||||
v := map[string]int{"one": 1}
|
||||
scs.Printf("v: %v\n", v)
|
||||
scs.Dump(v)
|
||||
|
||||
// Output:
|
||||
// v: map[one:1]
|
||||
// (map[string]int) (len=1) {
|
||||
// (string) (len=3) "one": (int) 1
|
||||
// }
|
||||
}
|
||||
|
||||
// This example demonstrates how to use ConfigState.Dump to dump variables to
|
||||
// stdout
|
||||
func ExampleConfigState_Dump() {
|
||||
// See the top-level Dump example for details on the types used in this
|
||||
// example.
|
||||
|
||||
// Create two ConfigState instances with different indentation.
|
||||
scs := spew.ConfigState{Indent: "\t"}
|
||||
scs2 := spew.ConfigState{Indent: " "}
|
||||
|
||||
// Setup some sample data structures for the example.
|
||||
bar := Bar{Flag(flagTwo), uintptr(0)}
|
||||
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
|
||||
|
||||
// Dump using the ConfigState instances.
|
||||
scs.Dump(s1)
|
||||
scs2.Dump(s1)
|
||||
|
||||
// Output:
|
||||
// (spew_test.Foo) {
|
||||
// unexportedField: (spew_test.Bar) {
|
||||
// flag: (spew_test.Flag) flagTwo,
|
||||
// data: (uintptr) <nil>
|
||||
// },
|
||||
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
// (string) (len=3) "one": (bool) true
|
||||
// }
|
||||
// }
|
||||
// (spew_test.Foo) {
|
||||
// unexportedField: (spew_test.Bar) {
|
||||
// flag: (spew_test.Flag) flagTwo,
|
||||
// data: (uintptr) <nil>
|
||||
// },
|
||||
// ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||
// (string) (len=3) "one": (bool) true
|
||||
// }
|
||||
// }
|
||||
//
|
||||
}
|
||||
|
||||
// This example demonstrates how to use ConfigState.Printf to display a variable
|
||||
// with a format string and inline formatting.
|
||||
func ExampleConfigState_Printf() {
|
||||
// See the top-level Dump example for details on the types used in this
|
||||
// example.
|
||||
|
||||
// Create two ConfigState instances and modify the method handling of the
|
||||
// first ConfigState only.
|
||||
scs := spew.NewDefaultConfig()
|
||||
scs2 := spew.NewDefaultConfig()
|
||||
scs.DisableMethods = true
|
||||
|
||||
// Alternatively
|
||||
// scs := spew.ConfigState{Indent: " ", DisableMethods: true}
|
||||
// scs2 := spew.ConfigState{Indent: " "}
|
||||
|
||||
// This is of type Flag which implements a Stringer and has raw value 1.
|
||||
f := flagTwo
|
||||
|
||||
// Dump using the ConfigState instances.
|
||||
scs.Printf("f: %v\n", f)
|
||||
scs2.Printf("f: %v\n", f)
|
||||
|
||||
// Output:
|
||||
// f: 1
|
||||
// f: flagTwo
|
||||
}
|
||||
|
|
@ -0,0 +1,413 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// supportedFlags is a list of all the character flags supported by fmt package.
|
||||
const supportedFlags = "0-+# "
|
||||
|
||||
// formatState implements the fmt.Formatter interface and contains information
|
||||
// about the state of a formatting operation. The NewFormatter function can
|
||||
// be used to get a new Formatter which can be used directly as arguments
|
||||
// in standard fmt package printing calls.
|
||||
type formatState struct {
|
||||
value interface{}
|
||||
fs fmt.State
|
||||
depth int
|
||||
pointers map[uintptr]int
|
||||
ignoreNextType bool
|
||||
cs *ConfigState
|
||||
}
|
||||
|
||||
// buildDefaultFormat recreates the original format string without precision
|
||||
// and width information to pass in to fmt.Sprintf in the case of an
|
||||
// unrecognized type. Unless new types are added to the language, this
|
||||
// function won't ever be called.
|
||||
func (f *formatState) buildDefaultFormat() (format string) {
|
||||
buf := bytes.NewBuffer(percentBytes)
|
||||
|
||||
for _, flag := range supportedFlags {
|
||||
if f.fs.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteRune('v')
|
||||
|
||||
format = buf.String()
|
||||
return format
|
||||
}
|
||||
|
||||
// constructOrigFormat recreates the original format string including precision
|
||||
// and width information to pass along to the standard fmt package. This allows
|
||||
// automatic deferral of all format strings this package doesn't support.
|
||||
func (f *formatState) constructOrigFormat(verb rune) (format string) {
|
||||
buf := bytes.NewBuffer(percentBytes)
|
||||
|
||||
for _, flag := range supportedFlags {
|
||||
if f.fs.Flag(int(flag)) {
|
||||
buf.WriteRune(flag)
|
||||
}
|
||||
}
|
||||
|
||||
if width, ok := f.fs.Width(); ok {
|
||||
buf.WriteString(strconv.Itoa(width))
|
||||
}
|
||||
|
||||
if precision, ok := f.fs.Precision(); ok {
|
||||
buf.Write(precisionBytes)
|
||||
buf.WriteString(strconv.Itoa(precision))
|
||||
}
|
||||
|
||||
buf.WriteRune(verb)
|
||||
|
||||
format = buf.String()
|
||||
return format
|
||||
}
|
||||
|
||||
// unpackValue returns values inside of non-nil interfaces when possible and
|
||||
// ensures that types for values which have been unpacked from an interface
|
||||
// are displayed when the show types flag is also set.
|
||||
// This is useful for data types like structs, arrays, slices, and maps which
|
||||
// can contain varying types packed inside an interface.
|
||||
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
|
||||
if v.Kind() == reflect.Interface {
|
||||
f.ignoreNextType = false
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// formatPtr handles formatting of pointers by indirecting them as necessary.
|
||||
func (f *formatState) formatPtr(v reflect.Value) {
|
||||
// Display nil if top level pointer is nil.
|
||||
showTypes := f.fs.Flag('#')
|
||||
if v.IsNil() && (!showTypes || f.ignoreNextType) {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Remove pointers at or below the current depth from map used to detect
|
||||
// circular refs.
|
||||
for k, depth := range f.pointers {
|
||||
if depth >= f.depth {
|
||||
delete(f.pointers, k)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep list of all dereferenced pointers to possibly show later.
|
||||
pointerChain := make([]uintptr, 0)
|
||||
|
||||
// Figure out how many levels of indirection there are by derferencing
|
||||
// pointers and unpacking interfaces down the chain while detecting circular
|
||||
// references.
|
||||
nilFound := false
|
||||
cycleFound := false
|
||||
indirects := 0
|
||||
ve := v
|
||||
for ve.Kind() == reflect.Ptr {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
indirects++
|
||||
addr := ve.Pointer()
|
||||
pointerChain = append(pointerChain, addr)
|
||||
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
|
||||
cycleFound = true
|
||||
indirects--
|
||||
break
|
||||
}
|
||||
f.pointers[addr] = f.depth
|
||||
|
||||
ve = ve.Elem()
|
||||
if ve.Kind() == reflect.Interface {
|
||||
if ve.IsNil() {
|
||||
nilFound = true
|
||||
break
|
||||
}
|
||||
ve = ve.Elem()
|
||||
}
|
||||
}
|
||||
|
||||
// Display type or indirection level depending on flags.
|
||||
if showTypes && !f.ignoreNextType {
|
||||
f.fs.Write(openParenBytes)
|
||||
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||
f.fs.Write([]byte(ve.Type().String()))
|
||||
f.fs.Write(closeParenBytes)
|
||||
} else {
|
||||
if nilFound || cycleFound {
|
||||
indirects += strings.Count(ve.Type().String(), "*")
|
||||
}
|
||||
f.fs.Write(openAngleBytes)
|
||||
f.fs.Write([]byte(strings.Repeat("*", indirects)))
|
||||
f.fs.Write(closeAngleBytes)
|
||||
}
|
||||
|
||||
// Display pointer information depending on flags.
|
||||
if f.fs.Flag('+') && (len(pointerChain) > 0) {
|
||||
f.fs.Write(openParenBytes)
|
||||
for i, addr := range pointerChain {
|
||||
if i > 0 {
|
||||
f.fs.Write(pointerChainBytes)
|
||||
}
|
||||
printHexPtr(f.fs, addr)
|
||||
}
|
||||
f.fs.Write(closeParenBytes)
|
||||
}
|
||||
|
||||
// Display dereferenced value.
|
||||
switch {
|
||||
case nilFound == true:
|
||||
f.fs.Write(nilAngleBytes)
|
||||
|
||||
case cycleFound == true:
|
||||
f.fs.Write(circularShortBytes)
|
||||
|
||||
default:
|
||||
f.ignoreNextType = true
|
||||
f.format(ve)
|
||||
}
|
||||
}
|
||||
|
||||
// format is the main workhorse for providing the Formatter interface. It
|
||||
// uses the passed reflect value to figure out what kind of object we are
|
||||
// dealing with and formats it appropriately. It is a recursive function,
|
||||
// however circular data structures are detected and handled properly.
|
||||
func (f *formatState) format(v reflect.Value) {
|
||||
// Handle invalid reflect values immediately.
|
||||
kind := v.Kind()
|
||||
if kind == reflect.Invalid {
|
||||
f.fs.Write(invalidAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle pointers specially.
|
||||
if kind == reflect.Ptr {
|
||||
f.formatPtr(v)
|
||||
return
|
||||
}
|
||||
|
||||
// Print type information unless already handled elsewhere.
|
||||
if !f.ignoreNextType && f.fs.Flag('#') {
|
||||
f.fs.Write(openParenBytes)
|
||||
f.fs.Write([]byte(v.Type().String()))
|
||||
f.fs.Write(closeParenBytes)
|
||||
}
|
||||
f.ignoreNextType = false
|
||||
|
||||
// Call Stringer/error interfaces if they exist and the handle methods
|
||||
// flag is enabled.
|
||||
if !f.cs.DisableMethods {
|
||||
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||
if handled := handleMethods(f.cs, f.fs, v); handled {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case reflect.Invalid:
|
||||
// Do nothing. We should never get here since invalid has already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Bool:
|
||||
printBool(f.fs, v.Bool())
|
||||
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
printInt(f.fs, v.Int(), 10)
|
||||
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
printUint(f.fs, v.Uint(), 10)
|
||||
|
||||
case reflect.Float32:
|
||||
printFloat(f.fs, v.Float(), 32)
|
||||
|
||||
case reflect.Float64:
|
||||
printFloat(f.fs, v.Float(), 64)
|
||||
|
||||
case reflect.Complex64:
|
||||
printComplex(f.fs, v.Complex(), 32)
|
||||
|
||||
case reflect.Complex128:
|
||||
printComplex(f.fs, v.Complex(), 64)
|
||||
|
||||
case reflect.Slice:
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
break
|
||||
}
|
||||
fallthrough
|
||||
|
||||
case reflect.Array:
|
||||
f.fs.Write(openBracketBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
numEntries := v.Len()
|
||||
for i := 0; i < numEntries; i++ {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(v.Index(i)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeBracketBytes)
|
||||
|
||||
case reflect.String:
|
||||
f.fs.Write([]byte(v.String()))
|
||||
|
||||
case reflect.Interface:
|
||||
// The only time we should get here is for nil interfaces due to
|
||||
// unpackValue calls.
|
||||
if v.IsNil() {
|
||||
f.fs.Write(nilAngleBytes)
|
||||
}
|
||||
|
||||
case reflect.Ptr:
|
||||
// Do nothing. We should never get here since pointers have already
|
||||
// been handled above.
|
||||
|
||||
case reflect.Map:
|
||||
f.fs.Write(openMapBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
keys := v.MapKeys()
|
||||
if f.cs.SortKeys {
|
||||
sortValues(keys)
|
||||
}
|
||||
for i, key := range keys {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(key))
|
||||
f.fs.Write(colonBytes)
|
||||
f.ignoreNextType = true
|
||||
f.format(f.unpackValue(v.MapIndex(key)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeMapBytes)
|
||||
|
||||
case reflect.Struct:
|
||||
numFields := v.NumField()
|
||||
f.fs.Write(openBraceBytes)
|
||||
f.depth++
|
||||
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||
f.fs.Write(maxShortBytes)
|
||||
} else {
|
||||
vt := v.Type()
|
||||
for i := 0; i < numFields; i++ {
|
||||
if i > 0 {
|
||||
f.fs.Write(spaceBytes)
|
||||
}
|
||||
vtf := vt.Field(i)
|
||||
if f.fs.Flag('+') || f.fs.Flag('#') {
|
||||
f.fs.Write([]byte(vtf.Name))
|
||||
f.fs.Write(colonBytes)
|
||||
}
|
||||
f.format(f.unpackValue(v.Field(i)))
|
||||
}
|
||||
}
|
||||
f.depth--
|
||||
f.fs.Write(closeBraceBytes)
|
||||
|
||||
case reflect.Uintptr:
|
||||
printHexPtr(f.fs, uintptr(v.Uint()))
|
||||
|
||||
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||
printHexPtr(f.fs, v.Pointer())
|
||||
|
||||
// There were not any other types at the time this code was written, but
|
||||
// fall back to letting the default fmt package handle it if any get added.
|
||||
default:
|
||||
format := f.buildDefaultFormat()
|
||||
if v.CanInterface() {
|
||||
fmt.Fprintf(f.fs, format, v.Interface())
|
||||
} else {
|
||||
fmt.Fprintf(f.fs, format, v.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
|
||||
// details.
|
||||
func (f *formatState) Format(fs fmt.State, verb rune) {
|
||||
f.fs = fs
|
||||
|
||||
// Use standard formatting for verbs that are not v.
|
||||
if verb != 'v' {
|
||||
format := f.constructOrigFormat(verb)
|
||||
fmt.Fprintf(fs, format, f.value)
|
||||
return
|
||||
}
|
||||
|
||||
if f.value == nil {
|
||||
if fs.Flag('#') {
|
||||
fs.Write(interfaceBytes)
|
||||
}
|
||||
fs.Write(nilAngleBytes)
|
||||
return
|
||||
}
|
||||
|
||||
f.format(reflect.ValueOf(f.value))
|
||||
}
|
||||
|
||||
// newFormatter is a helper function to consolidate the logic from the various
|
||||
// public methods which take varying config states.
|
||||
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
|
||||
fs := &formatState{value: v, cs: cs}
|
||||
fs.pointers = make(map[uintptr]int)
|
||||
return fs
|
||||
}
|
||||
|
||||
/*
|
||||
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||
interface. As a result, it integrates cleanly with standard fmt package
|
||||
printing functions. The formatter is useful for inline printing of smaller data
|
||||
types similar to the standard %v format specifier.
|
||||
|
||||
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||
the width and precision arguments (however they will still work on the format
|
||||
specifiers not handled by the custom formatter).
|
||||
|
||||
Typically this function shouldn't be called directly. It is much easier to make
|
||||
use of the custom formatter by calling one of the convenience functions such as
|
||||
Printf, Println, or Fprintf.
|
||||
*/
|
||||
func NewFormatter(v interface{}) fmt.Formatter {
|
||||
return newFormatter(&Config, v)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,162 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
/*
|
||||
This test file is part of the spew package rather than than the spew_test
|
||||
package because it needs access to internals to properly test certain cases
|
||||
which are not possible via the public interface since they should never happen.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// dummyFmtState implements a fake fmt.State to use for testing invalid
|
||||
// reflect.Value handling. This is necessary because the fmt package catches
|
||||
// invalid values before invoking the formatter on them.
|
||||
type dummyFmtState struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (dfs *dummyFmtState) Flag(f int) bool {
|
||||
if f == int('+') {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (dfs *dummyFmtState) Precision() (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (dfs *dummyFmtState) Width() (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// TestInvalidReflectValue ensures the dump and formatter code handles an
|
||||
// invalid reflect value properly. This needs access to internal state since it
|
||||
// should never happen in real code and therefore can't be tested via the public
|
||||
// API.
|
||||
func TestInvalidReflectValue(t *testing.T) {
|
||||
i := 1
|
||||
|
||||
// Dump invalid reflect value.
|
||||
v := new(reflect.Value)
|
||||
buf := new(bytes.Buffer)
|
||||
d := dumpState{w: buf, cs: &Config}
|
||||
d.dump(*v)
|
||||
s := buf.String()
|
||||
want := "<invalid>"
|
||||
if s != want {
|
||||
t.Errorf("InvalidReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Formatter invalid reflect value.
|
||||
buf2 := new(dummyFmtState)
|
||||
f := formatState{value: *v, cs: &Config, fs: buf2}
|
||||
f.format(*v)
|
||||
s = buf2.String()
|
||||
want = "<invalid>"
|
||||
if s != want {
|
||||
t.Errorf("InvalidReflectValue #%d got: %s want: %s", i, s, want)
|
||||
}
|
||||
}
|
||||
|
||||
// flagRO, flagKindShift and flagKindWidth indicate various bit flags that the
|
||||
// reflect package uses internally to track kind and state information.
|
||||
const flagRO = 1 << 0
|
||||
const flagKindShift = 4
|
||||
const flagKindWidth = 5
|
||||
|
||||
// changeKind uses unsafe to intentionally change the kind of a reflect.Value to
|
||||
// the maximum kind value which does not exist. This is needed to test the
|
||||
// fallback code which punts to the standard fmt library for new types that
|
||||
// might get added to the language.
|
||||
func changeKind(v *reflect.Value, readOnly bool) {
|
||||
rvf := (*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + offsetFlag))
|
||||
*rvf = *rvf | ((1<<flagKindWidth - 1) << flagKindShift)
|
||||
if readOnly {
|
||||
*rvf |= flagRO
|
||||
} else {
|
||||
*rvf &= ^uintptr(flagRO)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddedReflectValue tests functionaly of the dump and formatter code which
|
||||
// falls back to the standard fmt library for new types that might get added to
|
||||
// the language.
|
||||
func TestAddedReflectValue(t *testing.T) {
|
||||
i := 1
|
||||
|
||||
// Dump using a reflect.Value that is exported.
|
||||
v := reflect.ValueOf(int8(5))
|
||||
changeKind(&v, false)
|
||||
buf := new(bytes.Buffer)
|
||||
d := dumpState{w: buf, cs: &Config}
|
||||
d.dump(v)
|
||||
s := buf.String()
|
||||
want := "(int8) 5"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Dump using a reflect.Value that is not exported.
|
||||
changeKind(&v, true)
|
||||
buf.Reset()
|
||||
d.dump(v)
|
||||
s = buf.String()
|
||||
want = "(int8) <int8 Value>"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Formatter using a reflect.Value that is exported.
|
||||
changeKind(&v, false)
|
||||
buf2 := new(dummyFmtState)
|
||||
f := formatState{value: v, cs: &Config, fs: buf2}
|
||||
f.format(v)
|
||||
s = buf2.String()
|
||||
want = "5"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
|
||||
}
|
||||
i++
|
||||
|
||||
// Formatter using a reflect.Value that is not exported.
|
||||
changeKind(&v, true)
|
||||
buf2.Reset()
|
||||
f = formatState{value: v, cs: &Config, fs: buf2}
|
||||
f.format(v)
|
||||
s = buf2.String()
|
||||
want = "<int8 Value>"
|
||||
if s != want {
|
||||
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
|
||||
}
|
||||
}
|
||||
|
||||
// SortValues makes the internal sortValues function available to the test
|
||||
// package.
|
||||
func SortValues(values []reflect.Value) {
|
||||
sortValues(values)
|
||||
}
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the formatted string as a value that satisfies error. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Errorf(format string, a ...interface{}) (err error) {
|
||||
return fmt.Errorf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprint(w, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintf(w, format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||
// passed with a default Formatter interface returned by NewFormatter. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||
return fmt.Fprintln(w, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Print(a ...interface{}) (n int, err error) {
|
||||
return fmt.Print(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Printf(format string, a ...interface{}) (n int, err error) {
|
||||
return fmt.Printf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the number of bytes written and any write error encountered. See
|
||||
// NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Println(a ...interface{}) (n int, err error) {
|
||||
return fmt.Println(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprint(a ...interface{}) string {
|
||||
return fmt.Sprint(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||
// passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprintf(format string, a ...interface{}) string {
|
||||
return fmt.Sprintf(format, convertArgs(a)...)
|
||||
}
|
||||
|
||||
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||
// were passed with a default Formatter interface returned by NewFormatter. It
|
||||
// returns the resulting string. See NewFormatter for formatting details.
|
||||
//
|
||||
// This function is shorthand for the following syntax:
|
||||
//
|
||||
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||
func Sprintln(a ...interface{}) string {
|
||||
return fmt.Sprintln(convertArgs(a)...)
|
||||
}
|
||||
|
||||
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||
// length with each argument converted to a default spew Formatter interface.
|
||||
func convertArgs(args []interface{}) (formatters []interface{}) {
|
||||
formatters = make([]interface{}, len(args))
|
||||
for index, arg := range args {
|
||||
formatters[index] = NewFormatter(arg)
|
||||
}
|
||||
return formatters
|
||||
}
|
||||
|
|
@ -0,0 +1,308 @@
|
|||
/*
|
||||
* Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
*
|
||||
* Permission to use, copy, modify, and distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*/
|
||||
|
||||
package spew_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// spewFunc is used to identify which public function of the spew package or
|
||||
// ConfigState a test applies to.
|
||||
type spewFunc int
|
||||
|
||||
const (
|
||||
fCSFdump spewFunc = iota
|
||||
fCSFprint
|
||||
fCSFprintf
|
||||
fCSFprintln
|
||||
fCSPrint
|
||||
fCSPrintln
|
||||
fCSSdump
|
||||
fCSSprint
|
||||
fCSSprintf
|
||||
fCSSprintln
|
||||
fCSErrorf
|
||||
fCSNewFormatter
|
||||
fErrorf
|
||||
fFprint
|
||||
fFprintln
|
||||
fPrint
|
||||
fPrintln
|
||||
fSdump
|
||||
fSprint
|
||||
fSprintf
|
||||
fSprintln
|
||||
)
|
||||
|
||||
// Map of spewFunc values to names for pretty printing.
|
||||
var spewFuncStrings = map[spewFunc]string{
|
||||
fCSFdump: "ConfigState.Fdump",
|
||||
fCSFprint: "ConfigState.Fprint",
|
||||
fCSFprintf: "ConfigState.Fprintf",
|
||||
fCSFprintln: "ConfigState.Fprintln",
|
||||
fCSSdump: "ConfigState.Sdump",
|
||||
fCSPrint: "ConfigState.Print",
|
||||
fCSPrintln: "ConfigState.Println",
|
||||
fCSSprint: "ConfigState.Sprint",
|
||||
fCSSprintf: "ConfigState.Sprintf",
|
||||
fCSSprintln: "ConfigState.Sprintln",
|
||||
fCSErrorf: "ConfigState.Errorf",
|
||||
fCSNewFormatter: "ConfigState.NewFormatter",
|
||||
fErrorf: "spew.Errorf",
|
||||
fFprint: "spew.Fprint",
|
||||
fFprintln: "spew.Fprintln",
|
||||
fPrint: "spew.Print",
|
||||
fPrintln: "spew.Println",
|
||||
fSdump: "spew.Sdump",
|
||||
fSprint: "spew.Sprint",
|
||||
fSprintf: "spew.Sprintf",
|
||||
fSprintln: "spew.Sprintln",
|
||||
}
|
||||
|
||||
func (f spewFunc) String() string {
|
||||
if s, ok := spewFuncStrings[f]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("Unknown spewFunc (%d)", int(f))
|
||||
}
|
||||
|
||||
// spewTest is used to describe a test to be performed against the public
|
||||
// functions of the spew package or ConfigState.
|
||||
type spewTest struct {
|
||||
cs *spew.ConfigState
|
||||
f spewFunc
|
||||
format string
|
||||
in interface{}
|
||||
want string
|
||||
}
|
||||
|
||||
// spewTests houses the tests to be performed against the public functions of
|
||||
// the spew package and ConfigState.
|
||||
//
|
||||
// These tests are only intended to ensure the public functions are exercised
|
||||
// and are intentionally not exhaustive of types. The exhaustive type
|
||||
// tests are handled in the dump and format tests.
|
||||
var spewTests []spewTest
|
||||
|
||||
// redirStdout is a helper function to return the standard output from f as a
|
||||
// byte slice.
|
||||
func redirStdout(f func()) ([]byte, error) {
|
||||
tempFile, err := ioutil.TempFile("", "ss-test")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileName := tempFile.Name()
|
||||
defer os.Remove(fileName) // Ignore error
|
||||
|
||||
origStdout := os.Stdout
|
||||
os.Stdout = tempFile
|
||||
f()
|
||||
os.Stdout = origStdout
|
||||
tempFile.Close()
|
||||
|
||||
return ioutil.ReadFile(fileName)
|
||||
}
|
||||
|
||||
func initSpewTests() {
|
||||
// Config states with various settings.
|
||||
scsDefault := spew.NewDefaultConfig()
|
||||
scsNoMethods := &spew.ConfigState{Indent: " ", DisableMethods: true}
|
||||
scsNoPmethods := &spew.ConfigState{Indent: " ", DisablePointerMethods: true}
|
||||
scsMaxDepth := &spew.ConfigState{Indent: " ", MaxDepth: 1}
|
||||
scsContinue := &spew.ConfigState{Indent: " ", ContinueOnMethod: true}
|
||||
|
||||
// Variables for tests on types which implement Stringer interface with and
|
||||
// without a pointer receiver.
|
||||
ts := stringer("test")
|
||||
tps := pstringer("test")
|
||||
|
||||
// depthTester is used to test max depth handling for structs, array, slices
|
||||
// and maps.
|
||||
type depthTester struct {
|
||||
ic indirCir1
|
||||
arr [1]string
|
||||
slice []string
|
||||
m map[string]int
|
||||
}
|
||||
dt := depthTester{indirCir1{nil}, [1]string{"arr"}, []string{"slice"},
|
||||
map[string]int{"one": 1}}
|
||||
|
||||
// Variable for tests on types which implement error interface.
|
||||
te := customError(10)
|
||||
|
||||
spewTests = []spewTest{
|
||||
{scsDefault, fCSFdump, "", int8(127), "(int8) 127\n"},
|
||||
{scsDefault, fCSFprint, "", int16(32767), "32767"},
|
||||
{scsDefault, fCSFprintf, "%v", int32(2147483647), "2147483647"},
|
||||
{scsDefault, fCSFprintln, "", int(2147483647), "2147483647\n"},
|
||||
{scsDefault, fCSPrint, "", int64(9223372036854775807), "9223372036854775807"},
|
||||
{scsDefault, fCSPrintln, "", uint8(255), "255\n"},
|
||||
{scsDefault, fCSSdump, "", uint8(64), "(uint8) 64\n"},
|
||||
{scsDefault, fCSSprint, "", complex(1, 2), "(1+2i)"},
|
||||
{scsDefault, fCSSprintf, "%v", complex(float32(3), 4), "(3+4i)"},
|
||||
{scsDefault, fCSSprintln, "", complex(float64(5), 6), "(5+6i)\n"},
|
||||
{scsDefault, fCSErrorf, "%#v", uint16(65535), "(uint16)65535"},
|
||||
{scsDefault, fCSNewFormatter, "%v", uint32(4294967295), "4294967295"},
|
||||
{scsDefault, fErrorf, "%v", uint64(18446744073709551615), "18446744073709551615"},
|
||||
{scsDefault, fFprint, "", float32(3.14), "3.14"},
|
||||
{scsDefault, fFprintln, "", float64(6.28), "6.28\n"},
|
||||
{scsDefault, fPrint, "", true, "true"},
|
||||
{scsDefault, fPrintln, "", false, "false\n"},
|
||||
{scsDefault, fSdump, "", complex(-10, -20), "(complex128) (-10-20i)\n"},
|
||||
{scsDefault, fSprint, "", complex(-1, -2), "(-1-2i)"},
|
||||
{scsDefault, fSprintf, "%v", complex(float32(-3), -4), "(-3-4i)"},
|
||||
{scsDefault, fSprintln, "", complex(float64(-5), -6), "(-5-6i)\n"},
|
||||
{scsNoMethods, fCSFprint, "", ts, "test"},
|
||||
{scsNoMethods, fCSFprint, "", &ts, "<*>test"},
|
||||
{scsNoMethods, fCSFprint, "", tps, "test"},
|
||||
{scsNoMethods, fCSFprint, "", &tps, "<*>test"},
|
||||
{scsNoPmethods, fCSFprint, "", ts, "stringer test"},
|
||||
{scsNoPmethods, fCSFprint, "", &ts, "<*>stringer test"},
|
||||
{scsNoPmethods, fCSFprint, "", tps, "test"},
|
||||
{scsNoPmethods, fCSFprint, "", &tps, "<*>stringer test"},
|
||||
{scsMaxDepth, fCSFprint, "", dt, "{{<max>} [<max>] [<max>] map[<max>]}"},
|
||||
{scsMaxDepth, fCSFdump, "", dt, "(spew_test.depthTester) {\n" +
|
||||
" ic: (spew_test.indirCir1) {\n <max depth reached>\n },\n" +
|
||||
" arr: ([1]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
|
||||
" slice: ([]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
|
||||
" m: (map[string]int) (len=1) {\n <max depth reached>\n }\n}\n"},
|
||||
{scsContinue, fCSFprint, "", ts, "(stringer test) test"},
|
||||
{scsContinue, fCSFdump, "", ts, "(spew_test.stringer) " +
|
||||
"(len=4) (stringer test) \"test\"\n"},
|
||||
{scsContinue, fCSFprint, "", te, "(error: 10) 10"},
|
||||
{scsContinue, fCSFdump, "", te, "(spew_test.customError) " +
|
||||
"(error: 10) 10\n"},
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpew executes all of the tests described by spewTests.
|
||||
func TestSpew(t *testing.T) {
|
||||
initSpewTests()
|
||||
|
||||
t.Logf("Running %d tests", len(spewTests))
|
||||
for i, test := range spewTests {
|
||||
buf := new(bytes.Buffer)
|
||||
switch test.f {
|
||||
case fCSFdump:
|
||||
test.cs.Fdump(buf, test.in)
|
||||
|
||||
case fCSFprint:
|
||||
test.cs.Fprint(buf, test.in)
|
||||
|
||||
case fCSFprintf:
|
||||
test.cs.Fprintf(buf, test.format, test.in)
|
||||
|
||||
case fCSFprintln:
|
||||
test.cs.Fprintln(buf, test.in)
|
||||
|
||||
case fCSPrint:
|
||||
b, err := redirStdout(func() { test.cs.Print(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fCSPrintln:
|
||||
b, err := redirStdout(func() { test.cs.Println(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fCSSdump:
|
||||
str := test.cs.Sdump(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSSprint:
|
||||
str := test.cs.Sprint(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSSprintf:
|
||||
str := test.cs.Sprintf(test.format, test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSSprintln:
|
||||
str := test.cs.Sprintln(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fCSErrorf:
|
||||
err := test.cs.Errorf(test.format, test.in)
|
||||
buf.WriteString(err.Error())
|
||||
|
||||
case fCSNewFormatter:
|
||||
fmt.Fprintf(buf, test.format, test.cs.NewFormatter(test.in))
|
||||
|
||||
case fErrorf:
|
||||
err := spew.Errorf(test.format, test.in)
|
||||
buf.WriteString(err.Error())
|
||||
|
||||
case fFprint:
|
||||
spew.Fprint(buf, test.in)
|
||||
|
||||
case fFprintln:
|
||||
spew.Fprintln(buf, test.in)
|
||||
|
||||
case fPrint:
|
||||
b, err := redirStdout(func() { spew.Print(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fPrintln:
|
||||
b, err := redirStdout(func() { spew.Println(test.in) })
|
||||
if err != nil {
|
||||
t.Errorf("%v #%d %v", test.f, i, err)
|
||||
continue
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
case fSdump:
|
||||
str := spew.Sdump(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fSprint:
|
||||
str := spew.Sprint(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fSprintf:
|
||||
str := spew.Sprintf(test.format, test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
case fSprintln:
|
||||
str := spew.Sprintln(test.in)
|
||||
buf.WriteString(str)
|
||||
|
||||
default:
|
||||
t.Errorf("%v #%d unrecognized function", test.f, i)
|
||||
continue
|
||||
}
|
||||
s := buf.String()
|
||||
if test.want != s {
|
||||
t.Errorf("ConfigState #%d\n got: %s want: %s", i, s, test.want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) 2013 Dave Collins <dave@davec.name>
|
||||
//
|
||||
// Permission to use, copy, modify, and distribute this software for any
|
||||
// purpose with or without fee is hereby granted, provided that the above
|
||||
// copyright notice and this permission notice appear in all copies.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
|
||||
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||
// when both cgo is supported and "-tags testcgo" is added to the go test
|
||||
// command line. This code should really only be in the dumpcgo_test.go file,
|
||||
// but unfortunately Go will not allow cgo in test files, so this is a
|
||||
// workaround to allow cgo types to be tested. This configuration is used
|
||||
// because spew itself does not require cgo to run even though it does handle
|
||||
// certain cgo types specially. Rather than forcing all clients to require cgo
|
||||
// and an external C compiler just to run the tests, this scheme makes them
|
||||
// optional.
|
||||
// +build cgo,testcgo
|
||||
|
||||
package testdata
|
||||
|
||||
/*
|
||||
#include <stdint.h>
|
||||
typedef unsigned char custom_uchar_t;
|
||||
|
||||
char *ncp = 0;
|
||||
char *cp = "test";
|
||||
char ca[6] = {'t', 'e', 's', 't', '2', '\0'};
|
||||
unsigned char uca[6] = {'t', 'e', 's', 't', '3', '\0'};
|
||||
signed char sca[6] = {'t', 'e', 's', 't', '4', '\0'};
|
||||
uint8_t ui8ta[6] = {'t', 'e', 's', 't', '5', '\0'};
|
||||
custom_uchar_t tuca[6] = {'t', 'e', 's', 't', '6', '\0'};
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// GetCgoNullCharPointer returns a null char pointer via cgo. This is only
|
||||
// used for tests.
|
||||
func GetCgoNullCharPointer() interface{} {
|
||||
return C.ncp
|
||||
}
|
||||
|
||||
// GetCgoCharPointer returns a char pointer via cgo. This is only used for
|
||||
// tests.
|
||||
func GetCgoCharPointer() interface{} {
|
||||
return C.cp
|
||||
}
|
||||
|
||||
// GetCgoCharArray returns a char array via cgo and the array's len and cap.
|
||||
// This is only used for tests.
|
||||
func GetCgoCharArray() (interface{}, int, int) {
|
||||
return C.ca, len(C.ca), cap(C.ca)
|
||||
}
|
||||
|
||||
// GetCgoUnsignedCharArray returns an unsigned char array via cgo and the
|
||||
// array's len and cap. This is only used for tests.
|
||||
func GetCgoUnsignedCharArray() (interface{}, int, int) {
|
||||
return C.uca, len(C.uca), cap(C.uca)
|
||||
}
|
||||
|
||||
// GetCgoSignedCharArray returns a signed char array via cgo and the array's len
|
||||
// and cap. This is only used for tests.
|
||||
func GetCgoSignedCharArray() (interface{}, int, int) {
|
||||
return C.sca, len(C.sca), cap(C.sca)
|
||||
}
|
||||
|
||||
// GetCgoUint8tArray returns a uint8_t array via cgo and the array's len and
|
||||
// cap. This is only used for tests.
|
||||
func GetCgoUint8tArray() (interface{}, int, int) {
|
||||
return C.ui8ta, len(C.ui8ta), cap(C.ui8ta)
|
||||
}
|
||||
|
||||
// GetCgoTypdefedUnsignedCharArray returns a typedefed unsigned char array via
|
||||
// cgo and the array's len and cap. This is only used for tests.
|
||||
func GetCgoTypdefedUnsignedCharArray() (interface{}, int, int) {
|
||||
return C.tuca, len(C.tuca), cap(C.tuca)
|
||||
}
|
||||
|
|
@ -0,0 +1,373 @@
|
|||
Mozilla Public License Version 2.0
|
||||
==================================
|
||||
|
||||
1. Definitions
|
||||
--------------
|
||||
|
||||
1.1. "Contributor"
|
||||
means each individual or legal entity that creates, contributes to
|
||||
the creation of, or owns Covered Software.
|
||||
|
||||
1.2. "Contributor Version"
|
||||
means the combination of the Contributions of others (if any) used
|
||||
by a Contributor and that particular Contributor's Contribution.
|
||||
|
||||
1.3. "Contribution"
|
||||
means Covered Software of a particular Contributor.
|
||||
|
||||
1.4. "Covered Software"
|
||||
means Source Code Form to which the initial Contributor has attached
|
||||
the notice in Exhibit A, the Executable Form of such Source Code
|
||||
Form, and Modifications of such Source Code Form, in each case
|
||||
including portions thereof.
|
||||
|
||||
1.5. "Incompatible With Secondary Licenses"
|
||||
means
|
||||
|
||||
(a) that the initial Contributor has attached the notice described
|
||||
in Exhibit B to the Covered Software; or
|
||||
|
||||
(b) that the Covered Software was made available under the terms of
|
||||
version 1.1 or earlier of the License, but not also under the
|
||||
terms of a Secondary License.
|
||||
|
||||
1.6. "Executable Form"
|
||||
means any form of the work other than Source Code Form.
|
||||
|
||||
1.7. "Larger Work"
|
||||
means a work that combines Covered Software with other material, in
|
||||
a separate file or files, that is not Covered Software.
|
||||
|
||||
1.8. "License"
|
||||
means this document.
|
||||
|
||||
1.9. "Licensable"
|
||||
means having the right to grant, to the maximum extent possible,
|
||||
whether at the time of the initial grant or subsequently, any and
|
||||
all of the rights conveyed by this License.
|
||||
|
||||
1.10. "Modifications"
|
||||
means any of the following:
|
||||
|
||||
(a) any file in Source Code Form that results from an addition to,
|
||||
deletion from, or modification of the contents of Covered
|
||||
Software; or
|
||||
|
||||
(b) any new file in Source Code Form that contains any Covered
|
||||
Software.
|
||||
|
||||
1.11. "Patent Claims" of a Contributor
|
||||
means any patent claim(s), including without limitation, method,
|
||||
process, and apparatus claims, in any patent Licensable by such
|
||||
Contributor that would be infringed, but for the grant of the
|
||||
License, by the making, using, selling, offering for sale, having
|
||||
made, import, or transfer of either its Contributions or its
|
||||
Contributor Version.
|
||||
|
||||
1.12. "Secondary License"
|
||||
means either the GNU General Public License, Version 2.0, the GNU
|
||||
Lesser General Public License, Version 2.1, the GNU Affero General
|
||||
Public License, Version 3.0, or any later versions of those
|
||||
licenses.
|
||||
|
||||
1.13. "Source Code Form"
|
||||
means the form of the work preferred for making modifications.
|
||||
|
||||
1.14. "You" (or "Your")
|
||||
means an individual or a legal entity exercising rights under this
|
||||
License. For legal entities, "You" includes any entity that
|
||||
controls, is controlled by, or is under common control with You. For
|
||||
purposes of this definition, "control" means (a) the power, direct
|
||||
or indirect, to cause the direction or management of such entity,
|
||||
whether by contract or otherwise, or (b) ownership of more than
|
||||
fifty percent (50%) of the outstanding shares or beneficial
|
||||
ownership of such entity.
|
||||
|
||||
2. License Grants and Conditions
|
||||
--------------------------------
|
||||
|
||||
2.1. Grants
|
||||
|
||||
Each Contributor hereby grants You a world-wide, royalty-free,
|
||||
non-exclusive license:
|
||||
|
||||
(a) under intellectual property rights (other than patent or trademark)
|
||||
Licensable by such Contributor to use, reproduce, make available,
|
||||
modify, display, perform, distribute, and otherwise exploit its
|
||||
Contributions, either on an unmodified basis, with Modifications, or
|
||||
as part of a Larger Work; and
|
||||
|
||||
(b) under Patent Claims of such Contributor to make, use, sell, offer
|
||||
for sale, have made, import, and otherwise transfer either its
|
||||
Contributions or its Contributor Version.
|
||||
|
||||
2.2. Effective Date
|
||||
|
||||
The licenses granted in Section 2.1 with respect to any Contribution
|
||||
become effective for each Contribution on the date the Contributor first
|
||||
distributes such Contribution.
|
||||
|
||||
2.3. Limitations on Grant Scope
|
||||
|
||||
The licenses granted in this Section 2 are the only rights granted under
|
||||
this License. No additional rights or licenses will be implied from the
|
||||
distribution or licensing of Covered Software under this License.
|
||||
Notwithstanding Section 2.1(b) above, no patent license is granted by a
|
||||
Contributor:
|
||||
|
||||
(a) for any code that a Contributor has removed from Covered Software;
|
||||
or
|
||||
|
||||
(b) for infringements caused by: (i) Your and any other third party's
|
||||
modifications of Covered Software, or (ii) the combination of its
|
||||
Contributions with other software (except as part of its Contributor
|
||||
Version); or
|
||||
|
||||
(c) under Patent Claims infringed by Covered Software in the absence of
|
||||
its Contributions.
|
||||
|
||||
This License does not grant any rights in the trademarks, service marks,
|
||||
or logos of any Contributor (except as may be necessary to comply with
|
||||
the notice requirements in Section 3.4).
|
||||
|
||||
2.4. Subsequent Licenses
|
||||
|
||||
No Contributor makes additional grants as a result of Your choice to
|
||||
distribute the Covered Software under a subsequent version of this
|
||||
License (see Section 10.2) or under the terms of a Secondary License (if
|
||||
permitted under the terms of Section 3.3).
|
||||
|
||||
2.5. Representation
|
||||
|
||||
Each Contributor represents that the Contributor believes its
|
||||
Contributions are its original creation(s) or it has sufficient rights
|
||||
to grant the rights to its Contributions conveyed by this License.
|
||||
|
||||
2.6. Fair Use
|
||||
|
||||
This License is not intended to limit any rights You have under
|
||||
applicable copyright doctrines of fair use, fair dealing, or other
|
||||
equivalents.
|
||||
|
||||
2.7. Conditions
|
||||
|
||||
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
|
||||
in Section 2.1.
|
||||
|
||||
3. Responsibilities
|
||||
-------------------
|
||||
|
||||
3.1. Distribution of Source Form
|
||||
|
||||
All distribution of Covered Software in Source Code Form, including any
|
||||
Modifications that You create or to which You contribute, must be under
|
||||
the terms of this License. You must inform recipients that the Source
|
||||
Code Form of the Covered Software is governed by the terms of this
|
||||
License, and how they can obtain a copy of this License. You may not
|
||||
attempt to alter or restrict the recipients' rights in the Source Code
|
||||
Form.
|
||||
|
||||
3.2. Distribution of Executable Form
|
||||
|
||||
If You distribute Covered Software in Executable Form then:
|
||||
|
||||
(a) such Covered Software must also be made available in Source Code
|
||||
Form, as described in Section 3.1, and You must inform recipients of
|
||||
the Executable Form how they can obtain a copy of such Source Code
|
||||
Form by reasonable means in a timely manner, at a charge no more
|
||||
than the cost of distribution to the recipient; and
|
||||
|
||||
(b) You may distribute such Executable Form under the terms of this
|
||||
License, or sublicense it under different terms, provided that the
|
||||
license for the Executable Form does not attempt to limit or alter
|
||||
the recipients' rights in the Source Code Form under this License.
|
||||
|
||||
3.3. Distribution of a Larger Work
|
||||
|
||||
You may create and distribute a Larger Work under terms of Your choice,
|
||||
provided that You also comply with the requirements of this License for
|
||||
the Covered Software. If the Larger Work is a combination of Covered
|
||||
Software with a work governed by one or more Secondary Licenses, and the
|
||||
Covered Software is not Incompatible With Secondary Licenses, this
|
||||
License permits You to additionally distribute such Covered Software
|
||||
under the terms of such Secondary License(s), so that the recipient of
|
||||
the Larger Work may, at their option, further distribute the Covered
|
||||
Software under the terms of either this License or such Secondary
|
||||
License(s).
|
||||
|
||||
3.4. Notices
|
||||
|
||||
You may not remove or alter the substance of any license notices
|
||||
(including copyright notices, patent notices, disclaimers of warranty,
|
||||
or limitations of liability) contained within the Source Code Form of
|
||||
the Covered Software, except that You may alter any license notices to
|
||||
the extent required to remedy known factual inaccuracies.
|
||||
|
||||
3.5. Application of Additional Terms
|
||||
|
||||
You may choose to offer, and to charge a fee for, warranty, support,
|
||||
indemnity or liability obligations to one or more recipients of Covered
|
||||
Software. However, You may do so only on Your own behalf, and not on
|
||||
behalf of any Contributor. You must make it absolutely clear that any
|
||||
such warranty, support, indemnity, or liability obligation is offered by
|
||||
You alone, and You hereby agree to indemnify every Contributor for any
|
||||
liability incurred by such Contributor as a result of warranty, support,
|
||||
indemnity or liability terms You offer. You may include additional
|
||||
disclaimers of warranty and limitations of liability specific to any
|
||||
jurisdiction.
|
||||
|
||||
4. Inability to Comply Due to Statute or Regulation
|
||||
---------------------------------------------------
|
||||
|
||||
If it is impossible for You to comply with any of the terms of this
|
||||
License with respect to some or all of the Covered Software due to
|
||||
statute, judicial order, or regulation then You must: (a) comply with
|
||||
the terms of this License to the maximum extent possible; and (b)
|
||||
describe the limitations and the code they affect. Such description must
|
||||
be placed in a text file included with all distributions of the Covered
|
||||
Software under this License. Except to the extent prohibited by statute
|
||||
or regulation, such description must be sufficiently detailed for a
|
||||
recipient of ordinary skill to be able to understand it.
|
||||
|
||||
5. Termination
|
||||
--------------
|
||||
|
||||
5.1. The rights granted under this License will terminate automatically
|
||||
if You fail to comply with any of its terms. However, if You become
|
||||
compliant, then the rights granted under this License from a particular
|
||||
Contributor are reinstated (a) provisionally, unless and until such
|
||||
Contributor explicitly and finally terminates Your grants, and (b) on an
|
||||
ongoing basis, if such Contributor fails to notify You of the
|
||||
non-compliance by some reasonable means prior to 60 days after You have
|
||||
come back into compliance. Moreover, Your grants from a particular
|
||||
Contributor are reinstated on an ongoing basis if such Contributor
|
||||
notifies You of the non-compliance by some reasonable means, this is the
|
||||
first time You have received notice of non-compliance with this License
|
||||
from such Contributor, and You become compliant prior to 30 days after
|
||||
Your receipt of the notice.
|
||||
|
||||
5.2. If You initiate litigation against any entity by asserting a patent
|
||||
infringement claim (excluding declaratory judgment actions,
|
||||
counter-claims, and cross-claims) alleging that a Contributor Version
|
||||
directly or indirectly infringes any patent, then the rights granted to
|
||||
You by any and all Contributors for the Covered Software under Section
|
||||
2.1 of this License shall terminate.
|
||||
|
||||
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
|
||||
end user license agreements (excluding distributors and resellers) which
|
||||
have been validly granted by You or Your distributors under this License
|
||||
prior to termination shall survive termination.
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 6. Disclaimer of Warranty *
|
||||
* ------------------------- *
|
||||
* *
|
||||
* Covered Software is provided under this License on an "as is" *
|
||||
* basis, without warranty of any kind, either expressed, implied, or *
|
||||
* statutory, including, without limitation, warranties that the *
|
||||
* Covered Software is free of defects, merchantable, fit for a *
|
||||
* particular purpose or non-infringing. The entire risk as to the *
|
||||
* quality and performance of the Covered Software is with You. *
|
||||
* Should any Covered Software prove defective in any respect, You *
|
||||
* (not any Contributor) assume the cost of any necessary servicing, *
|
||||
* repair, or correction. This disclaimer of warranty constitutes an *
|
||||
* essential part of this License. No use of any Covered Software is *
|
||||
* authorized under this License except under this disclaimer. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
************************************************************************
|
||||
* *
|
||||
* 7. Limitation of Liability *
|
||||
* -------------------------- *
|
||||
* *
|
||||
* Under no circumstances and under no legal theory, whether tort *
|
||||
* (including negligence), contract, or otherwise, shall any *
|
||||
* Contributor, or anyone who distributes Covered Software as *
|
||||
* permitted above, be liable to You for any direct, indirect, *
|
||||
* special, incidental, or consequential damages of any character *
|
||||
* including, without limitation, damages for lost profits, loss of *
|
||||
* goodwill, work stoppage, computer failure or malfunction, or any *
|
||||
* and all other commercial damages or losses, even if such party *
|
||||
* shall have been informed of the possibility of such damages. This *
|
||||
* limitation of liability shall not apply to liability for death or *
|
||||
* personal injury resulting from such party's negligence to the *
|
||||
* extent applicable law prohibits such limitation. Some *
|
||||
* jurisdictions do not allow the exclusion or limitation of *
|
||||
* incidental or consequential damages, so this exclusion and *
|
||||
* limitation may not apply to You. *
|
||||
* *
|
||||
************************************************************************
|
||||
|
||||
8. Litigation
|
||||
-------------
|
||||
|
||||
Any litigation relating to this License may be brought only in the
|
||||
courts of a jurisdiction where the defendant maintains its principal
|
||||
place of business and such litigation shall be governed by laws of that
|
||||
jurisdiction, without reference to its conflict-of-law provisions.
|
||||
Nothing in this Section shall prevent a party's ability to bring
|
||||
cross-claims or counter-claims.
|
||||
|
||||
9. Miscellaneous
|
||||
----------------
|
||||
|
||||
This License represents the complete agreement concerning the subject
|
||||
matter hereof. If any provision of this License is held to be
|
||||
unenforceable, such provision shall be reformed only to the extent
|
||||
necessary to make it enforceable. Any law or regulation which provides
|
||||
that the language of a contract shall be construed against the drafter
|
||||
shall not be used to construe this License against a Contributor.
|
||||
|
||||
10. Versions of the License
|
||||
---------------------------
|
||||
|
||||
10.1. New Versions
|
||||
|
||||
Mozilla Foundation is the license steward. Except as provided in Section
|
||||
10.3, no one other than the license steward has the right to modify or
|
||||
publish new versions of this License. Each version will be given a
|
||||
distinguishing version number.
|
||||
|
||||
10.2. Effect of New Versions
|
||||
|
||||
You may distribute the Covered Software under the terms of the version
|
||||
of the License under which You originally received the Covered Software,
|
||||
or under the terms of any subsequent version published by the license
|
||||
steward.
|
||||
|
||||
10.3. Modified Versions
|
||||
|
||||
If you create software not governed by this License, and you want to
|
||||
create a new license for such software, you may create and use a
|
||||
modified version of this License if you rename the license and remove
|
||||
any references to the name of the license steward (except to note that
|
||||
such modified license differs from this License).
|
||||
|
||||
10.4. Distributing Source Code Form that is Incompatible With Secondary
|
||||
Licenses
|
||||
|
||||
If You choose to distribute Source Code Form that is Incompatible With
|
||||
Secondary Licenses under the terms of this version of the License, the
|
||||
notice described in Exhibit B of this License must be attached.
|
||||
|
||||
Exhibit A - Source Code Form License Notice
|
||||
-------------------------------------------
|
||||
|
||||
This Source Code Form is subject to the terms of the Mozilla Public
|
||||
License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
If it is not possible or desirable to put the notice in a particular
|
||||
file, then You may include the notice in a location (such as a LICENSE
|
||||
file in a relevant directory) where a recipient would be likely to look
|
||||
for such a notice.
|
||||
|
||||
You may add additional accurate notices of copyright ownership.
|
||||
|
||||
Exhibit B - "Incompatible With Secondary Licenses" Notice
|
||||
---------------------------------------------------------
|
||||
|
||||
This Source Code Form is "Incompatible With Secondary Licenses", as
|
||||
defined by the Mozilla Public License, v. 2.0.
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
# Go-MySQL-Driver
|
||||
|
||||
A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) package
|
||||
|
||||

|
||||
|
||||
**Current tagged Release:** March 2, 2013 (stable beta 4)
|
||||
|
||||
[](https://travis-ci.org/go-sql-driver/mysql) *(master branch)*
|
||||
|
||||
---------------------------------------
|
||||
* [Features](#features)
|
||||
* [Requirements](#requirements)
|
||||
* [Installation](#installation)
|
||||
* [Usage](#usage)
|
||||
* [DSN (Data Source Name)](#dsn-data-source-name)
|
||||
* [Password](#password)
|
||||
* [Protocol](#protocol)
|
||||
* [Address](#address)
|
||||
* [Parameters](#parameters)
|
||||
* [Examples](#examples)
|
||||
* [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support)
|
||||
* [Testing / Development](#testing--development)
|
||||
* [License](#license)
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## Features
|
||||
* Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance")
|
||||
* Native Go implementation. No C-bindings, just pure Go
|
||||
* Connections over TCP/IPv4, TCP/IPv6 or Unix domain sockets
|
||||
* Automatic handling of broken connections
|
||||
* Automatic Connection Pooling *(by database/sql package)*
|
||||
* Supports queries larger than 16MB
|
||||
* Intelligent `LONG DATA` handling in prepared statements
|
||||
* Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support
|
||||
|
||||
## Requirements
|
||||
* Go 1.0.3 or higher
|
||||
* MySQL (Version 4.1 or higher), MariaDB or Percona Server
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## Installation
|
||||
Simple install the package to your [$GOPATH](http://code.google.com/p/go-wiki/wiki/GOPATH "GOPATH") with the [go tool](http://golang.org/cmd/go/ "go command") from shell:
|
||||
```bash
|
||||
$ go get github.com/go-sql-driver/mysql
|
||||
```
|
||||
Make sure [Git is installed](http://git-scm.com/downloads) on your machine and in your system's `PATH`.
|
||||
|
||||
## Usage
|
||||
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](http://golang.org/pkg/database/sql) API then.
|
||||
|
||||
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
|
||||
```go
|
||||
import "database/sql"
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
|
||||
db, e := sql.Open("mysql", "user:password@/dbname?charset=utf8")
|
||||
```
|
||||
|
||||
[Examples are available in our Wiki](https://github.com/go-sql-driver/mysql/wiki/Examples "Go-MySQL-Driver Examples").
|
||||
|
||||
|
||||
### DSN (Data Source Name)
|
||||
|
||||
The Data Source Name has a common format, like e.g. [PEAR DB](http://pear.php.net/manual/en/package.database.db.intro-dsn.php) uses it, but without type-prefix (optional parts marked by squared brackets):
|
||||
```
|
||||
[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN]
|
||||
```
|
||||
|
||||
A DSN in its fullest form:
|
||||
```
|
||||
username:password@protocol(address)/dbname?param=value
|
||||
```
|
||||
|
||||
Except of the databasename, all values are optional. So the minimal DSN is:
|
||||
```
|
||||
/dbname
|
||||
```
|
||||
|
||||
If you do not want to preselect a database, leave `dbname` empty:
|
||||
```
|
||||
/
|
||||
```
|
||||
|
||||
#### Password
|
||||
Passwords can consist of any character. Escaping is **not** necessary.
|
||||
|
||||
#### Protocol
|
||||
See [net.Dial](http://golang.org/pkg/net/#Dial) for more information which networks are available.
|
||||
In general you should use an Unix domain socket if available and TCP otherwise for best performance.
|
||||
|
||||
#### Address
|
||||
For TCP and UDP networks, addresses have the form `host:port`.
|
||||
If `host` is a literal IPv6 address, it must be enclosed in square brackets.
|
||||
The functions [net.JoinHostPort](http://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](http://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
|
||||
|
||||
For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`.
|
||||
|
||||
#### Parameters
|
||||
***Parameters are case-sensitive!***
|
||||
|
||||
Possible Parameters are:
|
||||
* `timeout`: **Driver** side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
|
||||
* `charset`: Sets the charset used for client-server interaction ("SET NAMES `value`"). If multiple charsets are set (separated by a comma), the following charset is used if setting the charset failes. This enables support for `utf8mb4` ([introduced in MySQL 5.5.3](http://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html)) with fallback to `utf8` for older servers (`charset=utf8mb4,utf8`).
|
||||
* `allowAllFiles`: `allowAllFiles=true` disables the file Whitelist for `LOAD DATA LOCAL INFILE` and allows *all* files. *Might be insecure!*
|
||||
|
||||
All other parameters are interpreted as system variables:
|
||||
* `autocommit`: *"SET autocommit=`value`"*
|
||||
* `time_zone`: *"SET time_zone=`value`"*
|
||||
* `tx_isolation`: *"SET [tx_isolation](https://dev.mysql.com/doc/refman/5.5/en/server-system-variables.html#sysvar_tx_isolation)=`value`"*
|
||||
* `param`: *"SET `param`=`value`"*
|
||||
|
||||
#### Examples
|
||||
```
|
||||
user@unix(/path/to/socket)/dbname
|
||||
```
|
||||
|
||||
```
|
||||
user:password@tcp(localhost:5555)/dbname?charset=utf8&autocommit=true
|
||||
```
|
||||
|
||||
```
|
||||
user:password@tcp([de:ad:be:ef::ca:fe]:80)/dbname?charset=utf8mb4,utf8
|
||||
```
|
||||
|
||||
```
|
||||
user:password@/dbname
|
||||
```
|
||||
|
||||
No Database preselected:
|
||||
```
|
||||
user:password@/
|
||||
```
|
||||
|
||||
### `LOAD DATA LOCAL INFILE` support
|
||||
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
|
||||
```go
|
||||
import "github.com/go-sql-driver/mysql"
|
||||
```
|
||||
|
||||
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` (might be insecure).
|
||||
|
||||
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then.
|
||||
|
||||
See also the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation")
|
||||
|
||||
|
||||
## Testing / Development
|
||||
To run the driver tests you may need to adjust the configuration. See [this Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
|
||||
|
||||
Go-MySQL-Driver is not feature-complete yet. Your help is very appreciated. If you want to contribute, you can work on an [open issue](https://github.com/go-sql-driver/mysql/issues?state=open).
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## License
|
||||
Go-MySQL-Driver is licensed under the [Mozilla Public License Version 2.0](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
|
||||
|
||||
Mozilla summarizes the license scope as follows:
|
||||
> MPL: The copyleft applies to any files containing MPLed code.
|
||||
|
||||
|
||||
That means:
|
||||
* You can **use** the **unchanged** source code both in private as also commercial
|
||||
* You **needn't publish** the source code of your library as long the files licensed under the MPL 2.0 are **unchanged**
|
||||
* You **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0)
|
||||
|
||||
Please read the [MPL 2.0 FAQ](http://www.mozilla.org/MPL/2.0/FAQ.html) if you have further questions regarding the license.
|
||||
|
||||
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBufSize = 4096
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
rd io.Reader
|
||||
idx int
|
||||
length int
|
||||
}
|
||||
|
||||
func newBuffer(rd io.Reader) *buffer {
|
||||
return &buffer{
|
||||
buf: make([]byte, defaultBufSize),
|
||||
rd: rd,
|
||||
}
|
||||
}
|
||||
|
||||
// fill reads at least _need_ bytes in the buffer
|
||||
// existing data in the buffer gets lost
|
||||
func (b *buffer) fill(need int) (err error) {
|
||||
b.idx = 0
|
||||
b.length = 0
|
||||
|
||||
var n int
|
||||
for b.length < need {
|
||||
n, err = b.rd.Read(b.buf[b.length:])
|
||||
b.length += n
|
||||
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
return // err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// read len(p) bytes
|
||||
func (b *buffer) read(p []byte) (err error) {
|
||||
need := len(p)
|
||||
|
||||
if b.length < need {
|
||||
if b.length > 0 {
|
||||
copy(p[0:b.length], b.buf[b.idx:])
|
||||
need -= b.length
|
||||
p = p[b.length:]
|
||||
|
||||
b.idx = 0
|
||||
b.length = 0
|
||||
}
|
||||
|
||||
if need >= len(b.buf) {
|
||||
var n int
|
||||
has := 0
|
||||
for err == nil && need > has {
|
||||
n, err = b.rd.Read(p[has:])
|
||||
has += n
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = b.fill(need) // err deferred
|
||||
}
|
||||
|
||||
copy(p, b.buf[b.idx:])
|
||||
b.idx += need
|
||||
b.length -= need
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,228 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type mysqlConn struct {
|
||||
cfg *config
|
||||
flags clientFlag
|
||||
charset byte
|
||||
cipher []byte
|
||||
netConn net.Conn
|
||||
buf *buffer
|
||||
protocol uint8
|
||||
sequence uint8
|
||||
affectedRows uint64
|
||||
insertId uint64
|
||||
maxPacketAllowed int
|
||||
maxWriteSize int
|
||||
}
|
||||
|
||||
type config struct {
|
||||
user string
|
||||
passwd string
|
||||
net string
|
||||
addr string
|
||||
dbname string
|
||||
params map[string]string
|
||||
}
|
||||
|
||||
// Handles parameters set in DSN
|
||||
func (mc *mysqlConn) handleParams() (err error) {
|
||||
for param, val := range mc.cfg.params {
|
||||
switch param {
|
||||
// Charset
|
||||
case "charset":
|
||||
charsets := strings.Split(val, ",")
|
||||
for i := range charsets {
|
||||
// ignore errors here - a charset may not exist
|
||||
err = mc.exec("SET NAMES " + charsets[i])
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// handled elsewhere
|
||||
case "timeout", "allowAllFiles":
|
||||
continue
|
||||
|
||||
// TLS-Encryption
|
||||
case "tls":
|
||||
err = errors.New("TLS-Encryption not implemented yet")
|
||||
return
|
||||
|
||||
// Compression
|
||||
case "compress":
|
||||
err = errors.New("Compression not implemented yet")
|
||||
|
||||
// System Vars
|
||||
default:
|
||||
err = mc.exec("SET " + param + "=" + val + "")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
||||
err := mc.exec("START TRANSACTION")
|
||||
if err == nil {
|
||||
return &mysqlTx{mc}, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Close() (err error) {
|
||||
mc.writeCommandPacket(comQuit)
|
||||
mc.cfg = nil
|
||||
mc.buf = nil
|
||||
mc.netConn.Close()
|
||||
mc.netConn = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt := &mysqlStmt{
|
||||
mc: mc,
|
||||
}
|
||||
|
||||
// Read Result
|
||||
columnCount, err := stmt.readPrepareResultPacket()
|
||||
if err == nil {
|
||||
if stmt.paramCount > 0 {
|
||||
stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if columnCount > 0 {
|
||||
err = stmt.mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
|
||||
return stmt, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if len(args) == 0 { // no args, fastpath
|
||||
mc.affectedRows = 0
|
||||
mc.insertId = 0
|
||||
|
||||
err := mc.exec(query)
|
||||
if err == nil {
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// with args, must use prepared stmt
|
||||
return nil, driver.ErrSkip
|
||||
|
||||
}
|
||||
|
||||
// Internal function to execute commands
|
||||
func (mc *mysqlConn) exec(query string) (err error) {
|
||||
// Send command
|
||||
err = mc.writeCommandPacketStr(comQuery, query)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = mc.readResultSetHeaderPacket()
|
||||
if err == nil && resLen > 0 {
|
||||
err = mc.readUntilEOF()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = mc.readUntilEOF()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
if len(args) == 0 { // no args, fastpath
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comQuery, query)
|
||||
if err == nil {
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := &mysqlRows{mc, false, nil, false}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
rows.columns, err = mc.readColumns(resLen)
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// with args, must use prepared stmt
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
// Gets the value of the given MySQL System Variable
|
||||
func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
|
||||
// Send command
|
||||
err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
|
||||
if err == nil {
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
rows := &mysqlRows{mc, false, nil, false}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
rows.columns, err = mc.readColumns(resLen)
|
||||
}
|
||||
|
||||
dest := make([]driver.Value, resLen)
|
||||
err = rows.readRow(dest)
|
||||
if err == nil {
|
||||
val = dest[0].([]byte)
|
||||
err = mc.readUntilEOF()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
const (
|
||||
minProtocolVersion byte = 10
|
||||
maxPacketSize = 1<<24 - 1
|
||||
timeFormat = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
// MySQL constants documentation:
|
||||
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||||
|
||||
const (
|
||||
iOK byte = 0x00
|
||||
iLocalInFile byte = 0xfb
|
||||
iEOF byte = 0xfe
|
||||
iERR byte = 0xff
|
||||
)
|
||||
|
||||
type clientFlag uint32
|
||||
|
||||
const (
|
||||
clientLongPassword clientFlag = 1 << iota
|
||||
clientFoundRows
|
||||
clientLongFlag
|
||||
clientConnectWithDB
|
||||
clientNoSchema
|
||||
clientCompress
|
||||
clientODBC
|
||||
clientLocalFiles
|
||||
clientIgnoreSpace
|
||||
clientProtocol41
|
||||
clientInteractive
|
||||
clientSSL
|
||||
clientIgnoreSIGPIPE
|
||||
clientTransactions
|
||||
clientReserved
|
||||
clientSecureConn
|
||||
clientMultiStatements
|
||||
clientMultiResults
|
||||
)
|
||||
|
||||
const (
|
||||
comQuit byte = iota + 1
|
||||
comInitDB
|
||||
comQuery
|
||||
comFieldList
|
||||
comCreateDB
|
||||
comDropDB
|
||||
comRefresh
|
||||
comShutdown
|
||||
comStatistics
|
||||
comProcessInfo
|
||||
comConnect
|
||||
comProcessKill
|
||||
comDebug
|
||||
comPing
|
||||
comTime
|
||||
comDelayedInsert
|
||||
comChangeUser
|
||||
comBinlogDump
|
||||
comTableDump
|
||||
comConnectOut
|
||||
comRegiserSlave
|
||||
comStmtPrepare
|
||||
comStmtExecute
|
||||
comStmtSendLongData
|
||||
comStmtClose
|
||||
comStmtReset
|
||||
comSetOption
|
||||
comStmtFetch
|
||||
)
|
||||
|
||||
const (
|
||||
fieldTypeDecimal byte = iota
|
||||
fieldTypeTiny
|
||||
fieldTypeShort
|
||||
fieldTypeLong
|
||||
fieldTypeFloat
|
||||
fieldTypeDouble
|
||||
fieldTypeNULL
|
||||
fieldTypeTimestamp
|
||||
fieldTypeLongLong
|
||||
fieldTypeInt24
|
||||
fieldTypeDate
|
||||
fieldTypeTime
|
||||
fieldTypeDateTime
|
||||
fieldTypeYear
|
||||
fieldTypeNewDate
|
||||
fieldTypeVarChar
|
||||
fieldTypeBit
|
||||
)
|
||||
const (
|
||||
fieldTypeNewDecimal byte = iota + 0xf6
|
||||
fieldTypeEnum
|
||||
fieldTypeSet
|
||||
fieldTypeTinyBLOB
|
||||
fieldTypeMediumBLOB
|
||||
fieldTypeLongBLOB
|
||||
fieldTypeBLOB
|
||||
fieldTypeVarString
|
||||
fieldTypeString
|
||||
fieldTypeGeometry
|
||||
)
|
||||
|
||||
type fieldFlag uint16
|
||||
|
||||
const (
|
||||
flagNotNULL fieldFlag = 1 << iota
|
||||
flagPriKey
|
||||
flagUniqueKey
|
||||
flagMultipleKey
|
||||
flagBLOB
|
||||
flagUnsigned
|
||||
flagZeroFill
|
||||
flagBinary
|
||||
flagEnum
|
||||
flagAutoIncrement
|
||||
flagTimestamp
|
||||
flagSet
|
||||
flagUnknown1
|
||||
flagUnknown2
|
||||
flagUnknown3
|
||||
flagUnknown4
|
||||
)
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mysqlDriver struct{}
|
||||
|
||||
// Open new Connection.
|
||||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
|
||||
// the DSN string is formated
|
||||
func (d *mysqlDriver) Open(dsn string) (driver.Conn, error) {
|
||||
var err error
|
||||
|
||||
// New mysqlConn
|
||||
mc := &mysqlConn{
|
||||
cfg: parseDSN(dsn),
|
||||
maxPacketAllowed: maxPacketSize,
|
||||
maxWriteSize: maxPacketSize - 1,
|
||||
}
|
||||
|
||||
// Connect to Server
|
||||
if _, ok := mc.cfg.params["timeout"]; ok { // with timeout
|
||||
var timeout time.Duration
|
||||
timeout, err = time.ParseDuration(mc.cfg.params["timeout"])
|
||||
if err == nil {
|
||||
mc.netConn, err = net.DialTimeout(mc.cfg.net, mc.cfg.addr, timeout)
|
||||
}
|
||||
} else { // no timeout
|
||||
mc.netConn, err = net.Dial(mc.cfg.net, mc.cfg.addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mc.buf = newBuffer(mc.netConn)
|
||||
|
||||
// Reading Handshake Initialization Packet
|
||||
err = mc.readInitPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
err = mc.writeAuthPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read Result Packet
|
||||
err = mc.readResultOK()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get max allowed packet size
|
||||
maxap, err := mc.getSystemVar("max_allowed_packet")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mc.maxPacketAllowed = stringToInt(maxap) - 1
|
||||
if mc.maxPacketAllowed < maxPacketSize {
|
||||
mc.maxWriteSize = mc.maxPacketAllowed
|
||||
}
|
||||
|
||||
// Handle DSN Params
|
||||
err = mc.handleParams()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mc, err
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("mysql", &mysqlDriver{})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,20 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
errMalformPkt = errors.New("Malformed Packet")
|
||||
errPktSync = errors.New("Commands out of sync. You can't run this command now")
|
||||
errPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
|
||||
errOldPassword = errors.New("It seems like you are using old_passwords, which is unsupported. See https://github.com/go-sql-driver/mysql/wiki/old_passwords")
|
||||
errPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
|
||||
)
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
fileRegister map[string]bool
|
||||
readerRegister map[string]func() io.Reader
|
||||
)
|
||||
|
||||
func init() {
|
||||
fileRegister = make(map[string]bool)
|
||||
readerRegister = make(map[string]func() io.Reader)
|
||||
}
|
||||
|
||||
// RegisterLocalFile adds the given file to the file whitelist,
|
||||
// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
|
||||
// Alternatively you can allow the use of all local files with
|
||||
// the DSN parameter 'allowAllFiles=true'
|
||||
func RegisterLocalFile(filepath string) {
|
||||
fileRegister[strings.Trim(filepath, `"`)] = true
|
||||
}
|
||||
|
||||
// RegisterReaderHandler registers a handler function which is used
|
||||
// to receive a io.Reader.
|
||||
// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
|
||||
// If the handler returns a io.ReadCloser Close() is called when the
|
||||
// request is finished.
|
||||
func RegisterReaderHandler(name string, handler func() io.Reader) {
|
||||
readerRegister[name] = handler
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
||||
var rdr io.Reader
|
||||
data := make([]byte, 4+mc.maxWriteSize)
|
||||
|
||||
if strings.HasPrefix(name, "Reader::") { // io.Reader
|
||||
name = name[8:]
|
||||
handler, inMap := readerRegister[name]
|
||||
if handler != nil {
|
||||
rdr = handler()
|
||||
}
|
||||
if rdr == nil {
|
||||
if !inMap {
|
||||
err = fmt.Errorf("Reader '%s' is not registered", name)
|
||||
} else {
|
||||
err = fmt.Errorf("Reader '%s' is <nil>", name)
|
||||
}
|
||||
}
|
||||
} else { // File
|
||||
name = strings.Trim(name, `"`)
|
||||
if fileRegister[name] || mc.cfg.params[`allowAllFiles`] == `true` {
|
||||
rdr, err = os.Open(name)
|
||||
} else {
|
||||
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
|
||||
}
|
||||
}
|
||||
|
||||
if rdc, ok := rdr.(io.ReadCloser); ok {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
err = rdc.Close()
|
||||
} else {
|
||||
rdc.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// send content packets
|
||||
var ioErr error
|
||||
if err == nil {
|
||||
var n int
|
||||
for err == nil && ioErr == nil {
|
||||
n, err = rdr.Read(data[4:])
|
||||
if n > 0 {
|
||||
data[0] = byte(n)
|
||||
data[1] = byte(n >> 8)
|
||||
data[2] = byte(n >> 16)
|
||||
data[3] = mc.sequence
|
||||
ioErr = mc.writePacket(data[:4+n])
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
if ioErr != nil {
|
||||
errLog.Print(ioErr.Error())
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
}
|
||||
|
||||
// send empty packet (termination)
|
||||
ioErr = mc.writePacket([]byte{
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
mc.sequence,
|
||||
})
|
||||
if ioErr != nil {
|
||||
errLog.Print(ioErr.Error())
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
// read OK packet
|
||||
if err == nil {
|
||||
return mc.readResultOK()
|
||||
} else {
|
||||
mc.readPacket()
|
||||
}
|
||||
return err
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,23 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
type mysqlResult struct {
|
||||
affectedRows int64
|
||||
insertId int64
|
||||
}
|
||||
|
||||
func (res *mysqlResult) LastInsertId() (int64, error) {
|
||||
return res.insertId, nil
|
||||
}
|
||||
|
||||
func (res *mysqlResult) RowsAffected() (int64, error) {
|
||||
return res.affectedRows, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
type mysqlField struct {
|
||||
name string
|
||||
fieldType byte
|
||||
flags fieldFlag
|
||||
}
|
||||
|
||||
type mysqlRows struct {
|
||||
mc *mysqlConn
|
||||
binary bool
|
||||
columns []mysqlField
|
||||
eof bool
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Columns() (columns []string) {
|
||||
columns = make([]string, len(rows.columns))
|
||||
for i := range columns {
|
||||
columns[i] = rows.columns[i].name
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Close() (err error) {
|
||||
defer func() {
|
||||
rows.mc = nil
|
||||
}()
|
||||
|
||||
// Remove unread packets from stream
|
||||
if !rows.eof {
|
||||
if rows.mc == nil {
|
||||
return errors.New("Invalid Connection")
|
||||
}
|
||||
|
||||
err = rows.mc.readUntilEOF()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Next(dest []driver.Value) error {
|
||||
if rows.eof {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
if rows.mc == nil {
|
||||
return errors.New("Invalid Connection")
|
||||
}
|
||||
|
||||
// Fetch next row from stream
|
||||
var err error
|
||||
if rows.binary {
|
||||
err = rows.readBinaryRow(dest)
|
||||
} else {
|
||||
err = rows.readRow(dest)
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
rows.eof = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
type mysqlStmt struct {
|
||||
mc *mysqlConn
|
||||
id uint32
|
||||
paramCount int
|
||||
params []mysqlField
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Close() (err error) {
|
||||
err = stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
|
||||
stmt.mc = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) NumInput() int {
|
||||
return stmt.paramCount
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
stmt.mc.affectedRows = 0
|
||||
stmt.mc.insertId = 0
|
||||
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = stmt.mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
err = stmt.mc.readUntilEOF()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Rows
|
||||
err = stmt.mc.readUntilEOF()
|
||||
}
|
||||
if err == nil {
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(stmt.mc.affectedRows),
|
||||
insertId: int64(stmt.mc.insertId),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read Result
|
||||
var resLen int
|
||||
resLen, err = stmt.mc.readResultSetHeaderPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows := &mysqlRows{stmt.mc, true, nil, false}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
rows.columns, err = stmt.mc.readColumns(resLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return rows, err
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
type mysqlTx struct {
|
||||
mc *mysqlConn
|
||||
}
|
||||
|
||||
func (tx *mysqlTx) Commit() (err error) {
|
||||
err = tx.mc.exec("COMMIT")
|
||||
tx.mc = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (tx *mysqlTx) Rollback() (err error) {
|
||||
err = tx.mc.exec("ROLLBACK")
|
||||
tx.mc = nil
|
||||
return
|
||||
}
|
||||
|
|
@ -0,0 +1,242 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Logger
|
||||
var (
|
||||
errLog *log.Logger
|
||||
)
|
||||
|
||||
func init() {
|
||||
errLog = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
|
||||
|
||||
dsnPattern = regexp.MustCompile(
|
||||
`^(?:(?P<user>.*?)(?::(?P<passwd>.*))?@)?` + // [user[:password]@]
|
||||
`(?:(?P<net>[^\(]*)(?:\((?P<addr>[^\)]*)\))?)?` + // [net[(addr)]]
|
||||
`\/(?P<dbname>.*?)` + // /dbname
|
||||
`(?:\?(?P<params>[^\?]*))?$`) // [?param1=value1¶mN=valueN]
|
||||
}
|
||||
|
||||
// Data Source Name Parser
|
||||
var dsnPattern *regexp.Regexp
|
||||
|
||||
func parseDSN(dsn string) *config {
|
||||
cfg := new(config)
|
||||
cfg.params = make(map[string]string)
|
||||
|
||||
matches := dsnPattern.FindStringSubmatch(dsn)
|
||||
names := dsnPattern.SubexpNames()
|
||||
|
||||
for i, match := range matches {
|
||||
switch names[i] {
|
||||
case "user":
|
||||
cfg.user = match
|
||||
case "passwd":
|
||||
cfg.passwd = match
|
||||
case "net":
|
||||
cfg.net = match
|
||||
case "addr":
|
||||
cfg.addr = match
|
||||
case "dbname":
|
||||
cfg.dbname = match
|
||||
case "params":
|
||||
for _, v := range strings.Split(match, "&") {
|
||||
param := strings.SplitN(v, "=", 2)
|
||||
if len(param) != 2 {
|
||||
continue
|
||||
}
|
||||
cfg.params[param[0]] = param[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set default network if empty
|
||||
if cfg.net == "" {
|
||||
cfg.net = "tcp"
|
||||
}
|
||||
|
||||
// Set default adress if empty
|
||||
if cfg.addr == "" {
|
||||
cfg.addr = "127.0.0.1:3306"
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Encrypt password using 4.1+ method
|
||||
// http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol#4.1_and_later
|
||||
func scramblePassword(scramble, password []byte) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stage1Hash = SHA1(password)
|
||||
crypt := sha1.New()
|
||||
crypt.Write(password)
|
||||
stage1 := crypt.Sum(nil)
|
||||
|
||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||
// inner Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(stage1)
|
||||
hash := crypt.Sum(nil)
|
||||
|
||||
// outer Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(scramble)
|
||||
crypt.Write(hash)
|
||||
scramble = crypt.Sum(nil)
|
||||
|
||||
// token = scrambleHash XOR stage1Hash
|
||||
for i := range scramble {
|
||||
scramble[i] ^= stage1[i]
|
||||
}
|
||||
return scramble
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Convert from and to bytes *
|
||||
******************************************************************************/
|
||||
|
||||
func uint64ToBytes(n uint64) []byte {
|
||||
return []byte{
|
||||
byte(n),
|
||||
byte(n >> 8),
|
||||
byte(n >> 16),
|
||||
byte(n >> 24),
|
||||
byte(n >> 32),
|
||||
byte(n >> 40),
|
||||
byte(n >> 48),
|
||||
byte(n >> 56),
|
||||
}
|
||||
}
|
||||
|
||||
func uint64ToString(n uint64) []byte {
|
||||
var a [20]byte
|
||||
i := 20
|
||||
|
||||
// U+0030 = 0
|
||||
// ...
|
||||
// U+0039 = 9
|
||||
|
||||
var q uint64
|
||||
for n >= 10 {
|
||||
i--
|
||||
q = n / 10
|
||||
a[i] = uint8(n-q*10) + 0x30
|
||||
n = q
|
||||
}
|
||||
|
||||
i--
|
||||
a[i] = uint8(n) + 0x30
|
||||
|
||||
return a[i:]
|
||||
}
|
||||
|
||||
// treats string value as unsigned integer representation
|
||||
func stringToInt(b []byte) int {
|
||||
val := 0
|
||||
for i := range b {
|
||||
val *= 10
|
||||
val += int(b[i] - 0x30)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func readLengthEnodedString(b []byte) ([]byte, bool, int, error) {
|
||||
// Get length
|
||||
num, isNull, n := readLengthEncodedInteger(b)
|
||||
if num < 1 {
|
||||
return nil, isNull, n, nil
|
||||
}
|
||||
|
||||
n += int(num)
|
||||
|
||||
// Check data length
|
||||
if len(b) >= n {
|
||||
return b[n-int(num) : n], false, n, nil
|
||||
}
|
||||
return nil, false, n, io.EOF
|
||||
}
|
||||
|
||||
func skipLengthEnodedString(b []byte) (int, error) {
|
||||
// Get length
|
||||
num, _, n := readLengthEncodedInteger(b)
|
||||
if num < 1 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
n += int(num)
|
||||
|
||||
// Check data length
|
||||
if len(b) >= n {
|
||||
return n, nil
|
||||
}
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
func readLengthEncodedInteger(b []byte) (num uint64, isNull bool, n int) {
|
||||
switch b[0] {
|
||||
|
||||
// 251: NULL
|
||||
case 0xfb:
|
||||
n = 1
|
||||
isNull = true
|
||||
return
|
||||
|
||||
// 252: value of following 2
|
||||
case 0xfc:
|
||||
num = uint64(b[1]) | uint64(b[2])<<8
|
||||
n = 3
|
||||
return
|
||||
|
||||
// 253: value of following 3
|
||||
case 0xfd:
|
||||
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16
|
||||
n = 4
|
||||
return
|
||||
|
||||
// 254: value of following 8
|
||||
case 0xfe:
|
||||
num = uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 |
|
||||
uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 |
|
||||
uint64(b[7])<<48 | uint64(b[8])<<54
|
||||
n = 9
|
||||
return
|
||||
}
|
||||
|
||||
// 0-250: value of first byte
|
||||
num = uint64(b[0])
|
||||
n = 1
|
||||
return
|
||||
}
|
||||
|
||||
func lengthEncodedIntegerToBytes(n uint64) []byte {
|
||||
switch {
|
||||
case n <= 250:
|
||||
return []byte{byte(n)}
|
||||
|
||||
case n <= 0xffff:
|
||||
return []byte{0xfc, byte(n), byte(n >> 8)}
|
||||
|
||||
case n <= 0xffffff:
|
||||
return []byte{0xfd, byte(n), byte(n >> 8), byte(n >> 16)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2013 Julien Schmidt. All rights reserved.
|
||||
// http://www.julienschmidt.com
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testDSNs = []struct {
|
||||
in string
|
||||
out string
|
||||
}{
|
||||
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value]}"},
|
||||
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8]}"},
|
||||
{"user:password@tcp(localhost:5555)/dbname?charset=utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8]}"},
|
||||
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8]}"},
|
||||
{"user:password@/dbname", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[]}"},
|
||||
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[]}"},
|
||||
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[]}"},
|
||||
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[]}"},
|
||||
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[]}"},
|
||||
}
|
||||
|
||||
func TestDSNParser(t *testing.T) {
|
||||
var cfg *config
|
||||
var res string
|
||||
|
||||
for i, tst := range testDSNs {
|
||||
cfg = parseDSN(tst.in)
|
||||
res = fmt.Sprintf("%+v", cfg)
|
||||
if res != tst.out {
|
||||
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, tst.out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2016 The Xorm Authors
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the {organization} nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
# SQL builder
|
||||
|
||||
[](https://circleci.com/gh/go-xorm/builder/tree/master)
|
||||
|
||||
Package builder is a lightweight and fast SQL builder for Go and XORM.
|
||||
|
||||
Make sure you have installed Go 1.1+ and then:
|
||||
|
||||
go get github.com/go-xorm/builder
|
||||
|
||||
# Insert
|
||||
|
||||
```Go
|
||||
sql, args, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
|
||||
```
|
||||
|
||||
# Select
|
||||
|
||||
```Go
|
||||
sql, args, err := Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
|
||||
|
||||
sql, args, err = Select("c, d").From("table1").LeftJoin("table2", Eq{"table1.id": 1}.And(Lt{"table2.id": 3})).
|
||||
RightJoin("table3", "table2.id = table3.tid").Where(Eq{"a": 1}).ToSQL()
|
||||
```
|
||||
|
||||
# Update
|
||||
|
||||
```Go
|
||||
sql, args, err := Update(Eq{"a": 2}).From("table1").Where(Eq{"a": 1}).ToSQL()
|
||||
```
|
||||
|
||||
# Delete
|
||||
|
||||
```Go
|
||||
sql, args, err := Delete(Eq{"a": 1}).From("table1").ToSQL()
|
||||
```
|
||||
|
||||
# Conditions
|
||||
|
||||
* `Eq` is a redefine of a map, you can give one or more conditions to `Eq`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Eq{"a":1})
|
||||
// a=? [1]
|
||||
sql, args, _ := ToSQL(Eq{"b":"c"}.And(Eq{"c": 0}))
|
||||
// b=? AND c=? ["c", 0]
|
||||
sql, args, _ := ToSQL(Eq{"b":"c", "c":0})
|
||||
// b=? AND c=? ["c", 0]
|
||||
sql, args, _ := ToSQL(Eq{"b":"c"}.Or(Eq{"b":"d"}))
|
||||
// b=? OR b=? ["c", "d"]
|
||||
sql, args, _ := ToSQL(Eq{"b": []string{"c", "d"}})
|
||||
// b IN (?,?) ["c", "d"]
|
||||
sql, args, _ := ToSQL(Eq{"b": 1, "c":[]int{2, 3}})
|
||||
// b=? AND c IN (?,?) [1, 2, 3]
|
||||
```
|
||||
|
||||
* `Neq` is the same to `Eq`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Neq{"a":1})
|
||||
// a<>? [1]
|
||||
sql, args, _ := ToSQL(Neq{"b":"c"}.And(Neq{"c": 0}))
|
||||
// b<>? AND c<>? ["c", 0]
|
||||
sql, args, _ := ToSQL(Neq{"b":"c", "c":0})
|
||||
// b<>? AND c<>? ["c", 0]
|
||||
sql, args, _ := ToSQL(Neq{"b":"c"}.Or(Neq{"b":"d"}))
|
||||
// b<>? OR b<>? ["c", "d"]
|
||||
sql, args, _ := ToSQL(Neq{"b": []string{"c", "d"}})
|
||||
// b NOT IN (?,?) ["c", "d"]
|
||||
sql, args, _ := ToSQL(Neq{"b": 1, "c":[]int{2, 3}})
|
||||
// b<>? AND c NOT IN (?,?) [1, 2, 3]
|
||||
```
|
||||
|
||||
* `Gt`, `Gte`, `Lt`, `Lte`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Gt{"a", 1}.And(Gte{"b", 2}))
|
||||
// a>? AND b>=? [1, 2]
|
||||
sql, args, _ := ToSQL(Lt{"a", 1}.Or(Lte{"b", 2}))
|
||||
// a<? OR b<=? [1, 2]
|
||||
```
|
||||
|
||||
* `Like`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Like{"a", "c"})
|
||||
// a LIKE ? [%c%]
|
||||
```
|
||||
|
||||
* `Expr` you can customerize your sql with `Expr`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Expr("a = ? ", 1))
|
||||
// a = ? [1]
|
||||
sql, args, _ := ToSQL(Eq{"a": Expr("select id from table where c = ?", 1)})
|
||||
// a=(select id from table where c = ?) [1]
|
||||
```
|
||||
|
||||
* `In` and `NotIn`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(In("a", 1, 2, 3))
|
||||
// a IN (?,?,?) [1,2,3]
|
||||
sql, args, _ := ToSQL(In("a", []int{1, 2, 3}))
|
||||
// a IN (?,?,?) [1,2,3]
|
||||
sql, args, _ := ToSQL(In("a", Expr("select id from b where c = ?", 1))))
|
||||
// a IN (select id from b where c = ?) [1]
|
||||
```
|
||||
|
||||
* `IsNull` and `NotNull`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(IsNull{"a"})
|
||||
// a IS NULL []
|
||||
sql, args, _ := ToSQL(NotNull{"b"})
|
||||
// b IS NOT NULL []
|
||||
```
|
||||
|
||||
* `And(conds ...Cond)`, And can connect one or more condtions via And
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(And(Eq{"a":1}, Like{"b", "c"}, Neq{"d", 2}))
|
||||
// a=? AND b LIKE ? AND d<>? [1, %c%, 2]
|
||||
```
|
||||
|
||||
* `Or(conds ...Cond)`, Or can connect one or more conditions via Or
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Or(Eq{"a":1}, Like{"b", "c"}, Neq{"d", 2}))
|
||||
// a=? OR b LIKE ? OR d<>? [1, %c%, 2]
|
||||
sql, args, _ := ToSQL(Or(Eq{"a":1}, And(Like{"b", "c"}, Neq{"d", 2})))
|
||||
// a=? OR (b LIKE ? AND d<>?) [1, %c%, 2]
|
||||
```
|
||||
|
||||
* `Between`
|
||||
|
||||
```Go
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Between{"a", 1, 2})
|
||||
// a BETWEEN 1 AND 2
|
||||
```
|
||||
|
||||
* Define yourself conditions
|
||||
|
||||
Since `Cond` is an interface.
|
||||
|
||||
```Go
|
||||
type Cond interface {
|
||||
WriteTo(Writer) error
|
||||
And(...Cond) Cond
|
||||
Or(...Cond) Cond
|
||||
IsValid() bool
|
||||
}
|
||||
```
|
||||
|
||||
You can define yourself conditions and compose with other `Cond`.
|
||||
|
|
@ -0,0 +1,190 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
type optype byte
|
||||
|
||||
const (
|
||||
condType optype = iota // only conditions
|
||||
selectType // select
|
||||
insertType // insert
|
||||
updateType // update
|
||||
deleteType // delete
|
||||
)
|
||||
|
||||
type join struct {
|
||||
joinType string
|
||||
joinTable string
|
||||
joinCond Cond
|
||||
}
|
||||
|
||||
// Builder describes a SQL statement
|
||||
type Builder struct {
|
||||
optype
|
||||
tableName string
|
||||
cond Cond
|
||||
selects []string
|
||||
joins []join
|
||||
inserts Eq
|
||||
updates []Eq
|
||||
}
|
||||
|
||||
// Select creates a select Builder
|
||||
func Select(cols ...string) *Builder {
|
||||
builder := &Builder{cond: NewCond()}
|
||||
return builder.Select(cols...)
|
||||
}
|
||||
|
||||
// Insert creates an insert Builder
|
||||
func Insert(eq Eq) *Builder {
|
||||
builder := &Builder{cond: NewCond()}
|
||||
return builder.Insert(eq)
|
||||
}
|
||||
|
||||
// Update creates an update Builder
|
||||
func Update(updates ...Eq) *Builder {
|
||||
builder := &Builder{cond: NewCond()}
|
||||
return builder.Update(updates...)
|
||||
}
|
||||
|
||||
// Delete creates a delete Builder
|
||||
func Delete(conds ...Cond) *Builder {
|
||||
builder := &Builder{cond: NewCond()}
|
||||
return builder.Delete(conds...)
|
||||
}
|
||||
|
||||
// Where sets where SQL
|
||||
func (b *Builder) Where(cond Cond) *Builder {
|
||||
b.cond = b.cond.And(cond)
|
||||
return b
|
||||
}
|
||||
|
||||
// From sets the table name
|
||||
func (b *Builder) From(tableName string) *Builder {
|
||||
b.tableName = tableName
|
||||
return b
|
||||
}
|
||||
|
||||
// Into sets insert table name
|
||||
func (b *Builder) Into(tableName string) *Builder {
|
||||
b.tableName = tableName
|
||||
return b
|
||||
}
|
||||
|
||||
// Join sets join table and contions
|
||||
func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder {
|
||||
switch joinCond.(type) {
|
||||
case Cond:
|
||||
b.joins = append(b.joins, join{joinType, joinTable, joinCond.(Cond)})
|
||||
case string:
|
||||
b.joins = append(b.joins, join{joinType, joinTable, Expr(joinCond.(string))})
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// InnerJoin sets inner join
|
||||
func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder {
|
||||
return b.Join("INNER", joinTable, joinCond)
|
||||
}
|
||||
|
||||
// LeftJoin sets left join SQL
|
||||
func (b *Builder) LeftJoin(joinTable string, joinCond interface{}) *Builder {
|
||||
return b.Join("LEFT", joinTable, joinCond)
|
||||
}
|
||||
|
||||
// RightJoin sets right join SQL
|
||||
func (b *Builder) RightJoin(joinTable string, joinCond interface{}) *Builder {
|
||||
return b.Join("RIGHT", joinTable, joinCond)
|
||||
}
|
||||
|
||||
// CrossJoin sets cross join SQL
|
||||
func (b *Builder) CrossJoin(joinTable string, joinCond interface{}) *Builder {
|
||||
return b.Join("CROSS", joinTable, joinCond)
|
||||
}
|
||||
|
||||
// FullJoin sets full join SQL
|
||||
func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
|
||||
return b.Join("FULL", joinTable, joinCond)
|
||||
}
|
||||
|
||||
// Select sets select SQL
|
||||
func (b *Builder) Select(cols ...string) *Builder {
|
||||
b.selects = cols
|
||||
b.optype = selectType
|
||||
return b
|
||||
}
|
||||
|
||||
// And sets AND condition
|
||||
func (b *Builder) And(cond Cond) *Builder {
|
||||
b.cond = And(b.cond, cond)
|
||||
return b
|
||||
}
|
||||
|
||||
// Or sets OR condition
|
||||
func (b *Builder) Or(cond Cond) *Builder {
|
||||
b.cond = Or(b.cond, cond)
|
||||
return b
|
||||
}
|
||||
|
||||
// Insert sets insert SQL
|
||||
func (b *Builder) Insert(eq Eq) *Builder {
|
||||
b.inserts = eq
|
||||
b.optype = insertType
|
||||
return b
|
||||
}
|
||||
|
||||
// Update sets update SQL
|
||||
func (b *Builder) Update(updates ...Eq) *Builder {
|
||||
b.updates = updates
|
||||
b.optype = updateType
|
||||
return b
|
||||
}
|
||||
|
||||
// Delete sets delete SQL
|
||||
func (b *Builder) Delete(conds ...Cond) *Builder {
|
||||
b.cond = b.cond.And(conds...)
|
||||
b.optype = deleteType
|
||||
return b
|
||||
}
|
||||
|
||||
// WriteTo implements Writer interface
|
||||
func (b *Builder) WriteTo(w Writer) error {
|
||||
switch b.optype {
|
||||
case condType:
|
||||
return b.cond.WriteTo(w)
|
||||
case selectType:
|
||||
return b.selectWriteTo(w)
|
||||
case insertType:
|
||||
return b.insertWriteTo(w)
|
||||
case updateType:
|
||||
return b.updateWriteTo(w)
|
||||
case deleteType:
|
||||
return b.deleteWriteTo(w)
|
||||
}
|
||||
|
||||
return ErrNotSupportType
|
||||
}
|
||||
|
||||
// ToSQL convert a builder to SQL and args
|
||||
func (b *Builder) ToSQL() (string, []interface{}, error) {
|
||||
w := NewWriter()
|
||||
if err := b.WriteTo(w); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return w.writer.String(), w.args, nil
|
||||
}
|
||||
|
||||
// ToSQL convert a builder or condtions to SQL and args
|
||||
func ToSQL(cond interface{}) (string, []interface{}, error) {
|
||||
switch cond.(type) {
|
||||
case Cond:
|
||||
return condToSQL(cond.(Cond))
|
||||
case *Builder:
|
||||
return cond.(*Builder).ToSQL()
|
||||
}
|
||||
return "", nil, ErrNotSupportType
|
||||
}
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (b *Builder) deleteWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
return errors.New("no table indicated")
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.cond.WriteTo(w)
|
||||
}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (b *Builder) insertWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
return errors.New("no table indicated")
|
||||
}
|
||||
if len(b.inserts) <= 0 {
|
||||
return errors.New("no column to be update")
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var args = make([]interface{}, 0)
|
||||
var bs []byte
|
||||
var valBuffer = bytes.NewBuffer(bs)
|
||||
var i = 0
|
||||
for col, value := range b.inserts {
|
||||
fmt.Fprint(w, col)
|
||||
if e, ok := value.(expr); ok {
|
||||
fmt.Fprint(valBuffer, e.sql)
|
||||
args = append(args, e.args...)
|
||||
} else {
|
||||
fmt.Fprint(valBuffer, "?")
|
||||
args = append(args, value)
|
||||
}
|
||||
|
||||
if i != len(b.inserts)-1 {
|
||||
if _, err := fmt.Fprint(w, ","); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprint(valBuffer, ","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
i = i + 1
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, ") Values ("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.Write(valBuffer.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprint(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.Append(args...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (b *Builder) selectWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
return errors.New("no table indicated")
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, "SELECT "); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(b.selects) > 0 {
|
||||
for i, s := range b.selects {
|
||||
if _, err := fmt.Fprint(w, s); err != nil {
|
||||
return err
|
||||
}
|
||||
if i != len(b.selects)-1 {
|
||||
if _, err := fmt.Fprint(w, ","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, err := fmt.Fprint(w, "*"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, " FROM %s", b.tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, v := range b.joins {
|
||||
fmt.Fprintf(w, " %s JOIN %s ON ", v.joinType, v.joinTable)
|
||||
if err := v.joinCond.WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.cond.WriteTo(w)
|
||||
}
|
||||
|
|
@ -0,0 +1,197 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MyInt int
|
||||
|
||||
func TestBuilderCond(t *testing.T) {
|
||||
var cases = []struct {
|
||||
cond Cond
|
||||
sql string
|
||||
args []interface{}
|
||||
}{
|
||||
{
|
||||
Eq{"a": 1}.And(Like{"b", "c"}).Or(Eq{"a": 2}.And(Like{"b", "g"})),
|
||||
"(a=? AND b LIKE ?) OR (a=? AND b LIKE ?)",
|
||||
[]interface{}{1, "%c%", 2, "%g%"},
|
||||
},
|
||||
{
|
||||
Eq{"a": 1}.Or(Like{"b", "c"}).And(Eq{"a": 2}.Or(Like{"b", "g"})),
|
||||
"(a=? OR b LIKE ?) AND (a=? OR b LIKE ?)",
|
||||
[]interface{}{1, "%c%", 2, "%g%"},
|
||||
},
|
||||
{
|
||||
Eq{"d": []string{"e", "f"}},
|
||||
"d IN (?,?)",
|
||||
[]interface{}{"e", "f"},
|
||||
},
|
||||
{
|
||||
Neq{"d": []string{"e", "f"}},
|
||||
"d NOT IN (?,?)",
|
||||
[]interface{}{"e", "f"},
|
||||
},
|
||||
{
|
||||
Lt{"d": 3},
|
||||
"d<?",
|
||||
[]interface{}{3},
|
||||
},
|
||||
{
|
||||
Lte{"d": 3},
|
||||
"d<=?",
|
||||
[]interface{}{3},
|
||||
},
|
||||
{
|
||||
Gt{"d": 3},
|
||||
"d>?",
|
||||
[]interface{}{3},
|
||||
},
|
||||
{
|
||||
Gte{"d": 3},
|
||||
"d>=?",
|
||||
[]interface{}{3},
|
||||
},
|
||||
{
|
||||
Between{"d", 0, 2},
|
||||
"d BETWEEN ? AND ?",
|
||||
[]interface{}{0, 2},
|
||||
},
|
||||
{
|
||||
IsNull{"d"},
|
||||
"d IS NULL",
|
||||
[]interface{}{},
|
||||
},
|
||||
{
|
||||
NotIn("a", 1, 2).And(NotIn("b", "c", "d")),
|
||||
"a NOT IN (?,?) AND b NOT IN (?,?)",
|
||||
[]interface{}{1, 2, "c", "d"},
|
||||
},
|
||||
{
|
||||
In("a", 1, 2).Or(In("b", "c", "d")),
|
||||
"a IN (?,?) OR b IN (?,?)",
|
||||
[]interface{}{1, 2, "c", "d"},
|
||||
},
|
||||
{
|
||||
In("a", []int{1, 2}).Or(In("b", []string{"c", "d"})),
|
||||
"a IN (?,?) OR b IN (?,?)",
|
||||
[]interface{}{1, 2, "c", "d"},
|
||||
},
|
||||
{
|
||||
In("a", Expr("select id from x where name > ?", "b")),
|
||||
"a IN (select id from x where name > ?)",
|
||||
[]interface{}{"b"},
|
||||
},
|
||||
{
|
||||
In("a", []MyInt{1, 2}).Or(In("b", []string{"c", "d"})),
|
||||
"a IN (?,?) OR b IN (?,?)",
|
||||
[]interface{}{MyInt(1), MyInt(2), "c", "d"},
|
||||
},
|
||||
{
|
||||
In("a", []int{}),
|
||||
"a IN ()",
|
||||
[]interface{}{},
|
||||
},
|
||||
{
|
||||
NotIn("a", Expr("select id from x where name > ?", "b")),
|
||||
"a NOT IN (select id from x where name > ?)",
|
||||
[]interface{}{"b"},
|
||||
},
|
||||
{
|
||||
NotIn("a", []int{}),
|
||||
"a NOT IN ()",
|
||||
[]interface{}{},
|
||||
},
|
||||
// FIXME: since map will not guarantee the sequence, this may be failed random
|
||||
/*{
|
||||
Or(Eq{"a": 1, "b": 2}, Eq{"c": 3, "d": 4}),
|
||||
"(a=? AND b=?) OR (c=? AND d=?)",
|
||||
[]interface{}{1, 2, 3, 4},
|
||||
},*/
|
||||
}
|
||||
|
||||
for _, k := range cases {
|
||||
sql, args, err := ToSQL(k.cond)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if sql != k.sql {
|
||||
t.Error("want", k.sql, "get", sql)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql)
|
||||
|
||||
if !(len(args) == 0 && len(k.args) == 0) {
|
||||
if !reflect.DeepEqual(args, k.args) {
|
||||
t.Error("want", k.args, "get", args)
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Println(args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuilderSelect(t *testing.T) {
|
||||
sql, args, err := Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
|
||||
sql, args, err = Select("c, d").From("table1").LeftJoin("table2", Eq{"table1.id": 1}.And(Lt{"table2.id": 3})).
|
||||
RightJoin("table3", "table2.id = table3.tid").Where(Eq{"a": 1}).ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
}
|
||||
|
||||
func TestBuilderInsert(t *testing.T) {
|
||||
sql, args, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
}
|
||||
|
||||
func TestBuilderUpdate(t *testing.T) {
|
||||
sql, args, err := Update(Eq{"a": 2}).From("table1").Where(Eq{"a": 1}).ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
|
||||
sql, args, err = Update(Eq{"a": 2, "b": Incr(1)}).From("table2").Where(Eq{"a": 1}).ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
|
||||
sql, args, err = Update(Eq{"a": 2, "b": Incr(1), "c": Decr(1), "d": Expr("select count(*) from table2")}).From("table2").Where(Eq{"a": 1}).ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
}
|
||||
|
||||
func TestBuilderDelete(t *testing.T) {
|
||||
sql, args, err := Delete(Eq{"a": 1}).From("table1").ToSQL()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
fmt.Println(sql, args)
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (b *Builder) updateWriteTo(w Writer) error {
|
||||
if len(b.tableName) <= 0 {
|
||||
return errors.New("no table indicated")
|
||||
}
|
||||
if len(b.updates) <= 0 {
|
||||
return errors.New("no column to be update")
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, "UPDATE %s SET ", b.tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, s := range b.updates {
|
||||
if err := s.opWriteTo(",", w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if i != len(b.updates)-1 {
|
||||
if _, err := fmt.Fprint(w, ","); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprint(w, " WHERE "); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return b.cond.WriteTo(w)
|
||||
}
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
dependencies:
|
||||
override:
|
||||
# './...' is a relative pattern which means all subdirectories
|
||||
- go get -t -d -v ./...
|
||||
- go build -v
|
||||
- go get -u github.com/golang/lint/golint
|
||||
|
||||
test:
|
||||
override:
|
||||
# './...' is a relative pattern which means all subdirectories
|
||||
- golint ./...
|
||||
- go test -v -race
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Writer defines the interface
|
||||
type Writer interface {
|
||||
io.Writer
|
||||
Append(...interface{})
|
||||
}
|
||||
|
||||
var _ Writer = NewWriter()
|
||||
|
||||
// BytesWriter implments Writer and save SQL in bytes.Buffer
|
||||
type BytesWriter struct {
|
||||
writer *bytes.Buffer
|
||||
buffer []byte
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// NewWriter creates a new string writer
|
||||
func NewWriter() *BytesWriter {
|
||||
w := &BytesWriter{}
|
||||
w.writer = bytes.NewBuffer(w.buffer)
|
||||
return w
|
||||
}
|
||||
|
||||
// Write writes data to Writer
|
||||
func (s *BytesWriter) Write(buf []byte) (int, error) {
|
||||
return s.writer.Write(buf)
|
||||
}
|
||||
|
||||
// Append appends args to Writer
|
||||
func (s *BytesWriter) Append(args ...interface{}) {
|
||||
s.args = append(s.args, args...)
|
||||
}
|
||||
|
||||
// Cond defines an interface
|
||||
type Cond interface {
|
||||
WriteTo(Writer) error
|
||||
And(...Cond) Cond
|
||||
Or(...Cond) Cond
|
||||
IsValid() bool
|
||||
}
|
||||
|
||||
type condEmpty struct{}
|
||||
|
||||
var _ Cond = condEmpty{}
|
||||
|
||||
// NewCond creates an empty condition
|
||||
func NewCond() Cond {
|
||||
return condEmpty{}
|
||||
}
|
||||
|
||||
func (condEmpty) WriteTo(w Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (condEmpty) And(conds ...Cond) Cond {
|
||||
return And(conds...)
|
||||
}
|
||||
|
||||
func (condEmpty) Or(conds ...Cond) Cond {
|
||||
return Or(conds...)
|
||||
}
|
||||
|
||||
func (condEmpty) IsValid() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func condToSQL(cond Cond) (string, []interface{}, error) {
|
||||
if cond == nil || !cond.IsValid() {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
w := NewWriter()
|
||||
if err := cond.WriteTo(w); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return w.writer.String(), w.args, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
type condAnd []Cond
|
||||
|
||||
var _ Cond = condAnd{}
|
||||
|
||||
// And generates AND conditions
|
||||
func And(conds ...Cond) Cond {
|
||||
var result = make(condAnd, 0, len(conds))
|
||||
for _, cond := range conds {
|
||||
if cond == nil || !cond.IsValid() {
|
||||
continue
|
||||
}
|
||||
result = append(result, cond)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (and condAnd) WriteTo(w Writer) error {
|
||||
for i, cond := range and {
|
||||
_, isOr := cond.(condOr)
|
||||
if isOr {
|
||||
fmt.Fprint(w, "(")
|
||||
}
|
||||
|
||||
err := cond.WriteTo(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isOr {
|
||||
fmt.Fprint(w, ")")
|
||||
}
|
||||
|
||||
if i != len(and)-1 {
|
||||
fmt.Fprint(w, " AND ")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (and condAnd) And(conds ...Cond) Cond {
|
||||
return And(and, And(conds...))
|
||||
}
|
||||
|
||||
func (and condAnd) Or(conds ...Cond) Cond {
|
||||
return Or(and, Or(conds...))
|
||||
}
|
||||
|
||||
func (and condAnd) IsValid() bool {
|
||||
return len(and) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Between implmentes between condition
|
||||
type Between struct {
|
||||
Col string
|
||||
LessVal interface{}
|
||||
MoreVal interface{}
|
||||
}
|
||||
|
||||
var _ Cond = Between{}
|
||||
|
||||
// WriteTo write data to Writer
|
||||
func (between Between) WriteTo(w Writer) error {
|
||||
if _, err := fmt.Fprintf(w, "%s BETWEEN ? AND ?", between.Col); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(between.LessVal, between.MoreVal)
|
||||
return nil
|
||||
}
|
||||
|
||||
// And implments And with other conditions
|
||||
func (between Between) And(conds ...Cond) Cond {
|
||||
return And(between, And(conds...))
|
||||
}
|
||||
|
||||
// Or implments Or with other conditions
|
||||
func (between Between) Or(conds ...Cond) Cond {
|
||||
return Or(between, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if the condition is valid
|
||||
func (between Between) IsValid() bool {
|
||||
return len(between.Col) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// WriteMap writes conditions' SQL to Writer, op could be =, <>, >, <, <=, >= and etc.
|
||||
func WriteMap(w Writer, data map[string]interface{}, op string) error {
|
||||
var args = make([]interface{}, 0, len(data))
|
||||
var i = 0
|
||||
for k, v := range data {
|
||||
switch v.(type) {
|
||||
case expr:
|
||||
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.(expr).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Builder:
|
||||
if _, err := fmt.Fprintf(w, "%s%s(", k, op); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.(*Builder).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if _, err := fmt.Fprintf(w, "%s%s?", k, op); err != nil {
|
||||
return err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
if i != len(data)-1 {
|
||||
if _, err := fmt.Fprint(w, " AND "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
i = i + 1
|
||||
}
|
||||
w.Append(args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lt defines < condition
|
||||
type Lt map[string]interface{}
|
||||
|
||||
var _ Cond = Lt{}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (lt Lt) WriteTo(w Writer) error {
|
||||
return WriteMap(w, lt, "<")
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (lt Lt) And(conds ...Cond) Cond {
|
||||
return condAnd{lt, And(conds...)}
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (lt Lt) Or(conds ...Cond) Cond {
|
||||
return condOr{lt, Or(conds...)}
|
||||
}
|
||||
|
||||
// IsValid tests if this Eq is valid
|
||||
func (lt Lt) IsValid() bool {
|
||||
return len(lt) > 0
|
||||
}
|
||||
|
||||
// Lte defines <= condition
|
||||
type Lte map[string]interface{}
|
||||
|
||||
var _ Cond = Lte{}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (lte Lte) WriteTo(w Writer) error {
|
||||
return WriteMap(w, lte, "<=")
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (lte Lte) And(conds ...Cond) Cond {
|
||||
return And(lte, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (lte Lte) Or(conds ...Cond) Cond {
|
||||
return Or(lte, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this Eq is valid
|
||||
func (lte Lte) IsValid() bool {
|
||||
return len(lte) > 0
|
||||
}
|
||||
|
||||
// Gt defines > condition
|
||||
type Gt map[string]interface{}
|
||||
|
||||
var _ Cond = Gt{}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (gt Gt) WriteTo(w Writer) error {
|
||||
return WriteMap(w, gt, ">")
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (gt Gt) And(conds ...Cond) Cond {
|
||||
return And(gt, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (gt Gt) Or(conds ...Cond) Cond {
|
||||
return Or(gt, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this Eq is valid
|
||||
func (gt Gt) IsValid() bool {
|
||||
return len(gt) > 0
|
||||
}
|
||||
|
||||
// Gte defines >= condition
|
||||
type Gte map[string]interface{}
|
||||
|
||||
var _ Cond = Gte{}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (gte Gte) WriteTo(w Writer) error {
|
||||
return WriteMap(w, gte, ">=")
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (gte Gte) And(conds ...Cond) Cond {
|
||||
return And(gte, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (gte Gte) Or(conds ...Cond) Cond {
|
||||
return Or(gte, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this Eq is valid
|
||||
func (gte Gte) IsValid() bool {
|
||||
return len(gte) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Incr implements a type used by Eq
|
||||
type Incr int
|
||||
|
||||
// Decr implements a type used by Eq
|
||||
type Decr int
|
||||
|
||||
// Eq defines equals conditions
|
||||
type Eq map[string]interface{}
|
||||
|
||||
var _ Cond = Eq{}
|
||||
|
||||
func (eq Eq) opWriteTo(op string, w Writer) error {
|
||||
var i = 0
|
||||
for k, v := range eq {
|
||||
switch v.(type) {
|
||||
case []int, []int64, []string, []int32, []int16, []int8, []uint, []uint64, []uint32, []uint16, []interface{}:
|
||||
if err := In(k, v).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
case expr:
|
||||
if _, err := fmt.Fprintf(w, "%s=(", k); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.(expr).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Builder:
|
||||
if _, err := fmt.Fprintf(w, "%s=(", k); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.(*Builder).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case Incr:
|
||||
if _, err := fmt.Fprintf(w, "%s=%s+?", k, k); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(int(v.(Incr)))
|
||||
case Decr:
|
||||
if _, err := fmt.Fprintf(w, "%s=%s-?", k, k); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(int(v.(Decr)))
|
||||
default:
|
||||
if _, err := fmt.Fprintf(w, "%s=?", k); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(v)
|
||||
}
|
||||
if i != len(eq)-1 {
|
||||
if _, err := fmt.Fprint(w, op); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
i = i + 1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteTo writes SQL to Writer
|
||||
func (eq Eq) WriteTo(w Writer) error {
|
||||
return eq.opWriteTo(" AND ", w)
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (eq Eq) And(conds ...Cond) Cond {
|
||||
return And(eq, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (eq Eq) Or(conds ...Cond) Cond {
|
||||
return Or(eq, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this Eq is valid
|
||||
func (eq Eq) IsValid() bool {
|
||||
return len(eq) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
type expr struct {
|
||||
sql string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
var _ Cond = expr{}
|
||||
|
||||
// Expr generate customerize SQL
|
||||
func Expr(sql string, args ...interface{}) Cond {
|
||||
return expr{sql, args}
|
||||
}
|
||||
|
||||
func (expr expr) WriteTo(w Writer) error {
|
||||
if _, err := fmt.Fprint(w, expr.sql); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(expr.args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (expr expr) And(conds ...Cond) Cond {
|
||||
return And(expr, And(conds...))
|
||||
}
|
||||
|
||||
func (expr expr) Or(conds ...Cond) Cond {
|
||||
return Or(expr, Or(conds...))
|
||||
}
|
||||
|
||||
func (expr expr) IsValid() bool {
|
||||
return len(expr.sql) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type condIn struct {
|
||||
col string
|
||||
vals []interface{}
|
||||
}
|
||||
|
||||
var _ Cond = condIn{}
|
||||
|
||||
// In generates IN condition
|
||||
func In(col string, values ...interface{}) Cond {
|
||||
return condIn{col, values}
|
||||
}
|
||||
|
||||
func (condIn condIn) handleBlank(w Writer) error {
|
||||
if _, err := fmt.Fprintf(w, "%s IN ()", condIn.col); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (condIn condIn) WriteTo(w Writer) error {
|
||||
if len(condIn.vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
|
||||
switch condIn.vals[0].(type) {
|
||||
case []int8:
|
||||
vals := condIn.vals[0].([]int8)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int16:
|
||||
vals := condIn.vals[0].([]int16)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int:
|
||||
vals := condIn.vals[0].([]int)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int32:
|
||||
vals := condIn.vals[0].([]int32)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int64:
|
||||
vals := condIn.vals[0].([]int64)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint8:
|
||||
vals := condIn.vals[0].([]uint8)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint16:
|
||||
vals := condIn.vals[0].([]uint16)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint:
|
||||
vals := condIn.vals[0].([]uint)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint32:
|
||||
vals := condIn.vals[0].([]uint32)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint64:
|
||||
vals := condIn.vals[0].([]uint64)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []string:
|
||||
vals := condIn.vals[0].([]string)
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []interface{}:
|
||||
vals := condIn.vals[0].([]interface{})
|
||||
if len(vals) <= 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(vals...)
|
||||
case expr:
|
||||
val := condIn.vals[0].(expr)
|
||||
if _, err := fmt.Fprintf(w, "%s IN (", condIn.col); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := val.WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Builder:
|
||||
bd := condIn.vals[0].(*Builder)
|
||||
if _, err := fmt.Fprintf(w, "%s IN (", condIn.col); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bd.WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
v := reflect.ValueOf(condIn.vals[0])
|
||||
if v.Kind() == reflect.Slice {
|
||||
l := v.Len()
|
||||
if l == 0 {
|
||||
return condIn.handleBlank(w)
|
||||
}
|
||||
|
||||
questionMark := strings.Repeat("?,", l)
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < l; i++ {
|
||||
w.Append(v.Index(i).Interface())
|
||||
}
|
||||
} else {
|
||||
questionMark := strings.Repeat("?,", len(condIn.vals))
|
||||
if _, err := fmt.Fprintf(w, "%s IN (%s)", condIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(condIn.vals...)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (condIn condIn) And(conds ...Cond) Cond {
|
||||
return And(condIn, And(conds...))
|
||||
}
|
||||
|
||||
func (condIn condIn) Or(conds ...Cond) Cond {
|
||||
return Or(condIn, Or(conds...))
|
||||
}
|
||||
|
||||
func (condIn condIn) IsValid() bool {
|
||||
return len(condIn.col) > 0 && len(condIn.vals) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Like defines like condition
|
||||
type Like [2]string
|
||||
|
||||
var _ Cond = Like{"", ""}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (like Like) WriteTo(w Writer) error {
|
||||
if _, err := fmt.Fprintf(w, "%s LIKE ?", like[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
// FIXME: if use other regular express, this will be failed. but for compitable, keep this
|
||||
if like[1][0] == '%' || like[1][len(like[1])-1] == '%' {
|
||||
w.Append(like[1])
|
||||
} else {
|
||||
w.Append("%" + like[1] + "%")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (like Like) And(conds ...Cond) Cond {
|
||||
return And(like, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (like Like) Or(conds ...Cond) Cond {
|
||||
return Or(like, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this condition is valid
|
||||
func (like Like) IsValid() bool {
|
||||
return len(like[0]) > 0 && len(like[1]) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Neq defines not equal conditions
|
||||
type Neq map[string]interface{}
|
||||
|
||||
var _ Cond = Neq{}
|
||||
|
||||
// WriteTo writes SQL to Writer
|
||||
func (neq Neq) WriteTo(w Writer) error {
|
||||
var args = make([]interface{}, 0, len(neq))
|
||||
var i = 0
|
||||
for k, v := range neq {
|
||||
switch v.(type) {
|
||||
case []int, []int64, []string, []int32, []int16, []int8:
|
||||
if err := NotIn(k, v).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
case expr:
|
||||
if _, err := fmt.Fprintf(w, "%s<>(", k); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.(expr).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Builder:
|
||||
if _, err := fmt.Fprintf(w, "%s<>(", k); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.(*Builder).WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
if _, err := fmt.Fprintf(w, "%s<>?", k); err != nil {
|
||||
return err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
if i != len(neq)-1 {
|
||||
if _, err := fmt.Fprint(w, " AND "); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
i = i + 1
|
||||
}
|
||||
w.Append(args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (neq Neq) And(conds ...Cond) Cond {
|
||||
return And(neq, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (neq Neq) Or(conds ...Cond) Cond {
|
||||
return Or(neq, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this condition is valid
|
||||
func (neq Neq) IsValid() bool {
|
||||
return len(neq) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Not defines NOT condition
|
||||
type Not [1]Cond
|
||||
|
||||
var _ Cond = Not{}
|
||||
|
||||
// WriteTo writes SQL to Writer
|
||||
func (not Not) WriteTo(w Writer) error {
|
||||
if _, err := fmt.Fprint(w, "NOT "); err != nil {
|
||||
return err
|
||||
}
|
||||
switch not[0].(type) {
|
||||
case condAnd, condOr:
|
||||
if _, err := fmt.Fprint(w, "("); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := not[0].WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch not[0].(type) {
|
||||
case condAnd, condOr:
|
||||
if _, err := fmt.Fprint(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (not Not) And(conds ...Cond) Cond {
|
||||
return And(not, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (not Not) Or(conds ...Cond) Cond {
|
||||
return Or(not, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this condition is valid
|
||||
func (not Not) IsValid() bool {
|
||||
return not[0] != nil && not[0].IsValid()
|
||||
}
|
||||
|
|
@ -0,0 +1,236 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type condNotIn condIn
|
||||
|
||||
var _ Cond = condNotIn{}
|
||||
|
||||
// NotIn generate NOT IN condition
|
||||
func NotIn(col string, values ...interface{}) Cond {
|
||||
return condNotIn{col, values}
|
||||
}
|
||||
|
||||
func (condNotIn condNotIn) handleBlank(w Writer) error {
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN ()", condNotIn.col); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (condNotIn condNotIn) WriteTo(w Writer) error {
|
||||
if len(condNotIn.vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
|
||||
switch condNotIn.vals[0].(type) {
|
||||
case []int8:
|
||||
vals := condNotIn.vals[0].([]int8)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int16:
|
||||
vals := condNotIn.vals[0].([]int16)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int:
|
||||
vals := condNotIn.vals[0].([]int)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int32:
|
||||
vals := condNotIn.vals[0].([]int32)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []int64:
|
||||
vals := condNotIn.vals[0].([]int64)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint8:
|
||||
vals := condNotIn.vals[0].([]uint8)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint16:
|
||||
vals := condNotIn.vals[0].([]uint16)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint:
|
||||
vals := condNotIn.vals[0].([]uint)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint32:
|
||||
vals := condNotIn.vals[0].([]uint32)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []uint64:
|
||||
vals := condNotIn.vals[0].([]uint64)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []string:
|
||||
vals := condNotIn.vals[0].([]string)
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, val := range vals {
|
||||
w.Append(val)
|
||||
}
|
||||
case []interface{}:
|
||||
vals := condNotIn.vals[0].([]interface{})
|
||||
if len(vals) <= 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
questionMark := strings.Repeat("?,", len(vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(vals...)
|
||||
case expr:
|
||||
val := condNotIn.vals[0].(expr)
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (", condNotIn.col); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := val.WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
case *Builder:
|
||||
val := condNotIn.vals[0].(*Builder)
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (", condNotIn.col); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := val.WriteTo(w); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := fmt.Fprintf(w, ")"); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
v := reflect.ValueOf(condNotIn.vals[0])
|
||||
if v.Kind() == reflect.Slice {
|
||||
l := v.Len()
|
||||
if l == 0 {
|
||||
return condNotIn.handleBlank(w)
|
||||
}
|
||||
|
||||
questionMark := strings.Repeat("?,", l)
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < l; i++ {
|
||||
w.Append(v.Index(i).Interface())
|
||||
}
|
||||
} else {
|
||||
questionMark := strings.Repeat("?,", len(condNotIn.vals))
|
||||
if _, err := fmt.Fprintf(w, "%s NOT IN (%s)", condNotIn.col, questionMark[:len(questionMark)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Append(condNotIn.vals...)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (condNotIn condNotIn) And(conds ...Cond) Cond {
|
||||
return And(condNotIn, And(conds...))
|
||||
}
|
||||
|
||||
func (condNotIn condNotIn) Or(conds ...Cond) Cond {
|
||||
return Or(condNotIn, Or(conds...))
|
||||
}
|
||||
|
||||
func (condNotIn condNotIn) IsValid() bool {
|
||||
return len(condNotIn.col) > 0 && len(condNotIn.vals) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
// IsNull defines IS NULL condition
|
||||
type IsNull [1]string
|
||||
|
||||
var _ Cond = IsNull{""}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (isNull IsNull) WriteTo(w Writer) error {
|
||||
_, err := fmt.Fprintf(w, "%s IS NULL", isNull[0])
|
||||
return err
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (isNull IsNull) And(conds ...Cond) Cond {
|
||||
return And(isNull, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (isNull IsNull) Or(conds ...Cond) Cond {
|
||||
return Or(isNull, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this condition is valid
|
||||
func (isNull IsNull) IsValid() bool {
|
||||
return len(isNull[0]) > 0
|
||||
}
|
||||
|
||||
// NotNull defines NOT NULL condition
|
||||
type NotNull [1]string
|
||||
|
||||
var _ Cond = NotNull{""}
|
||||
|
||||
// WriteTo write SQL to Writer
|
||||
func (notNull NotNull) WriteTo(w Writer) error {
|
||||
_, err := fmt.Fprintf(w, "%s IS NOT NULL", notNull[0])
|
||||
return err
|
||||
}
|
||||
|
||||
// And implements And with other conditions
|
||||
func (notNull NotNull) And(conds ...Cond) Cond {
|
||||
return And(notNull, And(conds...))
|
||||
}
|
||||
|
||||
// Or implements Or with other conditions
|
||||
func (notNull NotNull) Or(conds ...Cond) Cond {
|
||||
return Or(notNull, Or(conds...))
|
||||
}
|
||||
|
||||
// IsValid tests if this condition is valid
|
||||
func (notNull NotNull) IsValid() bool {
|
||||
return len(notNull[0]) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "fmt"
|
||||
|
||||
type condOr []Cond
|
||||
|
||||
var _ Cond = condOr{}
|
||||
|
||||
// Or sets OR conditions
|
||||
func Or(conds ...Cond) Cond {
|
||||
var result = make(condOr, 0, len(conds))
|
||||
for _, cond := range conds {
|
||||
if cond == nil || !cond.IsValid() {
|
||||
continue
|
||||
}
|
||||
result = append(result, cond)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// WriteTo implments Cond
|
||||
func (o condOr) WriteTo(w Writer) error {
|
||||
for i, cond := range o {
|
||||
var needQuote bool
|
||||
switch cond.(type) {
|
||||
case condAnd:
|
||||
needQuote = true
|
||||
case Eq:
|
||||
needQuote = (len(cond.(Eq)) > 1)
|
||||
}
|
||||
|
||||
if needQuote {
|
||||
fmt.Fprint(w, "(")
|
||||
}
|
||||
|
||||
err := cond.WriteTo(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if needQuote {
|
||||
fmt.Fprint(w, ")")
|
||||
}
|
||||
|
||||
if i != len(o)-1 {
|
||||
fmt.Fprint(w, " OR ")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o condOr) And(conds ...Cond) Cond {
|
||||
return And(o, And(conds...))
|
||||
}
|
||||
|
||||
func (o condOr) Or(conds ...Cond) Cond {
|
||||
return Or(o, Or(conds...))
|
||||
}
|
||||
|
||||
func (o condOr) IsValid() bool {
|
||||
return len(o) > 0
|
||||
}
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
// Copyright 2016 The XORM Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
|
||||
Package builder is a simple and powerful sql builder for Go.
|
||||
|
||||
Make sure you have installed Go 1.1+ and then:
|
||||
|
||||
go get github.com/go-xorm/builder
|
||||
|
||||
WARNNING: Currently, only query conditions are supported. Below is the supported conditions.
|
||||
|
||||
1. Eq is a redefine of a map, you can give one or more conditions to Eq
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Eq{"a":1})
|
||||
// a=? [1]
|
||||
sql, args, _ := ToSQL(Eq{"b":"c"}.And(Eq{"c": 0}))
|
||||
// b=? AND c=? ["c", 0]
|
||||
sql, args, _ := ToSQL(Eq{"b":"c", "c":0})
|
||||
// b=? AND c=? ["c", 0]
|
||||
sql, args, _ := ToSQL(Eq{"b":"c"}.Or(Eq{"b":"d"}))
|
||||
// b=? OR b=? ["c", "d"]
|
||||
sql, args, _ := ToSQL(Eq{"b": []string{"c", "d"}})
|
||||
// b IN (?,?) ["c", "d"]
|
||||
sql, args, _ := ToSQL(Eq{"b": 1, "c":[]int{2, 3}})
|
||||
// b=? AND c IN (?,?) [1, 2, 3]
|
||||
|
||||
2. Neq is the same to Eq
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Neq{"a":1})
|
||||
// a<>? [1]
|
||||
sql, args, _ := ToSQL(Neq{"b":"c"}.And(Neq{"c": 0}))
|
||||
// b<>? AND c<>? ["c", 0]
|
||||
sql, args, _ := ToSQL(Neq{"b":"c", "c":0})
|
||||
// b<>? AND c<>? ["c", 0]
|
||||
sql, args, _ := ToSQL(Neq{"b":"c"}.Or(Neq{"b":"d"}))
|
||||
// b<>? OR b<>? ["c", "d"]
|
||||
sql, args, _ := ToSQL(Neq{"b": []string{"c", "d"}})
|
||||
// b NOT IN (?,?) ["c", "d"]
|
||||
sql, args, _ := ToSQL(Neq{"b": 1, "c":[]int{2, 3}})
|
||||
// b<>? AND c NOT IN (?,?) [1, 2, 3]
|
||||
|
||||
3. Gt, Gte, Lt, Lte
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Gt{"a", 1}.And(Gte{"b", 2}))
|
||||
// a>? AND b>=? [1, 2]
|
||||
sql, args, _ := ToSQL(Lt{"a", 1}.Or(Lte{"b", 2}))
|
||||
// a<? OR b<=? [1, 2]
|
||||
|
||||
4. Like
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Like{"a", "c"})
|
||||
// a LIKE ? [%c%]
|
||||
|
||||
5. Expr you can customerize your sql with Expr
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Expr("a = ? ", 1))
|
||||
// a = ? [1]
|
||||
sql, args, _ := ToSQL(Eq{"a": Expr("select id from table where c = ?", 1)})
|
||||
// a=(select id from table where c = ?) [1]
|
||||
|
||||
6. In and NotIn
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(In("a", 1, 2, 3))
|
||||
// a IN (?,?,?) [1,2,3]
|
||||
sql, args, _ := ToSQL(In("a", []int{1, 2, 3}))
|
||||
// a IN (?,?,?) [1,2,3]
|
||||
sql, args, _ := ToSQL(In("a", Expr("select id from b where c = ?", 1))))
|
||||
// a IN (select id from b where c = ?) [1]
|
||||
|
||||
7. IsNull and NotNull
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(IsNull{"a"})
|
||||
// a IS NULL []
|
||||
sql, args, _ := ToSQL(NotNull{"b"})
|
||||
// b IS NOT NULL []
|
||||
|
||||
8. And(conds ...Cond), And can connect one or more condtions via AND
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(And(Eq{"a":1}, Like{"b", "c"}, Neq{"d", 2}))
|
||||
// a=? AND b LIKE ? AND d<>? [1, %c%, 2]
|
||||
|
||||
9. Or(conds ...Cond), Or can connect one or more conditions via Or
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Or(Eq{"a":1}, Like{"b", "c"}, Neq{"d", 2}))
|
||||
// a=? OR b LIKE ? OR d<>? [1, %c%, 2]
|
||||
sql, args, _ := ToSQL(Or(Eq{"a":1}, And(Like{"b", "c"}, Neq{"d", 2})))
|
||||
// a=? OR (b LIKE ? AND d<>?) [1, %c%, 2]
|
||||
|
||||
10. Between
|
||||
|
||||
import . "github.com/go-xorm/builder"
|
||||
|
||||
sql, args, _ := ToSQL(Between("a", 1, 2))
|
||||
// a BETWEEN 1 AND 2
|
||||
|
||||
11. define yourself conditions
|
||||
Since Cond is a interface, you can define yourself conditions and compare with them
|
||||
*/
|
||||
package builder
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright 2016 The Xorm Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package builder
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNotSupportType not supported SQL type error
|
||||
ErrNotSupportType = errors.New("not supported SQL type")
|
||||
// ErrNoNotInConditions no NOT IN params error
|
||||
ErrNoNotInConditions = errors.New("No NOT IN conditions")
|
||||
// ErrNoInConditions no IN params error
|
||||
ErrNoInConditions = errors.New("No IN conditions")
|
||||
)
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
Copyright (c) 2013 - 2015 Lunny Xiao <xiaolunwen@gmail.com>
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the {organization} nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
|
@ -0,0 +1,116 @@
|
|||
Core is a lightweight wrapper of sql.DB.
|
||||
|
||||
[](https://circleci.com/gh/go-xorm/core/tree/master)
|
||||
|
||||
# Open
|
||||
```Go
|
||||
db, _ := core.Open(db, connstr)
|
||||
```
|
||||
|
||||
# SetMapper
|
||||
```Go
|
||||
db.SetMapper(SameMapper())
|
||||
```
|
||||
|
||||
## Scan usage
|
||||
|
||||
### Scan
|
||||
```Go
|
||||
rows, _ := db.Query()
|
||||
for rows.Next() {
|
||||
rows.Scan()
|
||||
}
|
||||
```
|
||||
|
||||
### ScanMap
|
||||
```Go
|
||||
rows, _ := db.Query()
|
||||
for rows.Next() {
|
||||
rows.ScanMap()
|
||||
```
|
||||
|
||||
### ScanSlice
|
||||
|
||||
You can use `[]string`, `[][]byte`, `[]interface{}`, `[]*string`, `[]sql.NullString` to ScanSclice. Notice, slice's length should be equal or less than select columns.
|
||||
|
||||
```Go
|
||||
rows, _ := db.Query()
|
||||
cols, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
var s = make([]string, len(cols))
|
||||
rows.ScanSlice(&s)
|
||||
}
|
||||
```
|
||||
|
||||
```Go
|
||||
rows, _ := db.Query()
|
||||
cols, _ := rows.Columns()
|
||||
for rows.Next() {
|
||||
var s = make([]*string, len(cols))
|
||||
rows.ScanSlice(&s)
|
||||
}
|
||||
```
|
||||
|
||||
### ScanStruct
|
||||
```Go
|
||||
rows, _ := db.Query()
|
||||
for rows.Next() {
|
||||
rows.ScanStructByName()
|
||||
rows.ScanStructByIndex()
|
||||
}
|
||||
```
|
||||
|
||||
## Query usage
|
||||
```Go
|
||||
rows, err := db.Query("select * from table where name = ?", name)
|
||||
|
||||
user = User{
|
||||
Name:"lunny",
|
||||
}
|
||||
rows, err := db.QueryStruct("select * from table where name = ?Name",
|
||||
&user)
|
||||
|
||||
var user = map[string]interface{}{
|
||||
"name": "lunny",
|
||||
}
|
||||
rows, err = db.QueryMap("select * from table where name = ?name",
|
||||
&user)
|
||||
```
|
||||
|
||||
## QueryRow usage
|
||||
```Go
|
||||
row := db.QueryRow("select * from table where name = ?", name)
|
||||
|
||||
user = User{
|
||||
Name:"lunny",
|
||||
}
|
||||
row := db.QueryRowStruct("select * from table where name = ?Name",
|
||||
&user)
|
||||
|
||||
var user = map[string]interface{}{
|
||||
"name": "lunny",
|
||||
}
|
||||
row = db.QueryRowMap("select * from table where name = ?name",
|
||||
&user)
|
||||
```
|
||||
|
||||
## Exec usage
|
||||
```Go
|
||||
db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)", name, title, age, alias...)
|
||||
|
||||
user = User{
|
||||
Name:"lunny",
|
||||
Title:"test",
|
||||
Age: 18,
|
||||
}
|
||||
result, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
|
||||
&user)
|
||||
|
||||
var user = map[string]interface{}{
|
||||
"Name": "lunny",
|
||||
"Title": "test",
|
||||
"Age": 18,
|
||||
}
|
||||
result, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
|
||||
&user)
|
||||
```
|
||||
|
|
@ -0,0 +1 @@
|
|||
go test -v -bench=. -run=XXX
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
)
|
||||
|
||||
const (
|
||||
// default cache expired time
|
||||
CacheExpired = 60 * time.Minute
|
||||
// not use now
|
||||
CacheMaxMemory = 256
|
||||
// evey ten minutes to clear all expired nodes
|
||||
CacheGcInterval = 10 * time.Minute
|
||||
// each time when gc to removed max nodes
|
||||
CacheGcMaxRemoved = 20
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCacheMiss = errors.New("xorm/cache: key not found.")
|
||||
ErrNotStored = errors.New("xorm/cache: not stored.")
|
||||
)
|
||||
|
||||
// CacheStore is a interface to store cache
|
||||
type CacheStore interface {
|
||||
// key is primary key or composite primary key
|
||||
// value is struct's pointer
|
||||
// key format : <tablename>-p-<pk1>-<pk2>...
|
||||
Put(key string, value interface{}) error
|
||||
Get(key string) (interface{}, error)
|
||||
Del(key string) error
|
||||
}
|
||||
|
||||
// Cacher is an interface to provide cache
|
||||
// id format : u-<pk1>-<pk2>...
|
||||
type Cacher interface {
|
||||
GetIds(tableName, sql string) interface{}
|
||||
GetBean(tableName string, id string) interface{}
|
||||
PutIds(tableName, sql string, ids interface{})
|
||||
PutBean(tableName string, id string, obj interface{})
|
||||
DelIds(tableName, sql string)
|
||||
DelBean(tableName string, id string)
|
||||
ClearIds(tableName string)
|
||||
ClearBeans(tableName string)
|
||||
}
|
||||
|
||||
func encodeIds(ids []PK) (string, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(ids)
|
||||
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
|
||||
func decodeIds(s string) ([]PK, error) {
|
||||
pks := make([]PK, 0)
|
||||
|
||||
dec := gob.NewDecoder(bytes.NewBufferString(s))
|
||||
err := dec.Decode(&pks)
|
||||
|
||||
return pks, err
|
||||
}
|
||||
|
||||
func GetCacheSql(m Cacher, tableName, sql string, args interface{}) ([]PK, error) {
|
||||
bytes := m.GetIds(tableName, GenSqlKey(sql, args))
|
||||
if bytes == nil {
|
||||
return nil, errors.New("Not Exist")
|
||||
}
|
||||
return decodeIds(bytes.(string))
|
||||
}
|
||||
|
||||
func PutCacheSql(m Cacher, ids []PK, tableName, sql string, args interface{}) error {
|
||||
bytes, err := encodeIds(ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.PutIds(tableName, GenSqlKey(sql, args), bytes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenSqlKey(sql string, args interface{}) string {
|
||||
return fmt.Sprintf("%v-%v", sql, args)
|
||||
}
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
dependencies:
|
||||
override:
|
||||
# './...' is a relative pattern which means all subdirectories
|
||||
- go get -t -d -v ./...
|
||||
- go build -v
|
||||
|
||||
database:
|
||||
override:
|
||||
- mysql -u root -e "CREATE DATABASE core_test DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci"
|
||||
|
||||
test:
|
||||
override:
|
||||
# './...' is a relative pattern which means all subdirectories
|
||||
- go test -v -race
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
TWOSIDES = iota + 1
|
||||
ONLYTODB
|
||||
ONLYFROMDB
|
||||
)
|
||||
|
||||
type AssociateType int
|
||||
|
||||
const (
|
||||
AssociateNone AssociateType = iota
|
||||
AssociateBelongsTo
|
||||
)
|
||||
|
||||
// database column
|
||||
type Column struct {
|
||||
Name string
|
||||
TableName string
|
||||
FieldName string
|
||||
SQLType SQLType
|
||||
Length int
|
||||
Length2 int
|
||||
Nullable bool
|
||||
Default string
|
||||
Indexes map[string]int
|
||||
IsPrimaryKey bool
|
||||
IsAutoIncrement bool
|
||||
MapType int
|
||||
IsCreated bool
|
||||
IsUpdated bool
|
||||
IsDeleted bool
|
||||
IsCascade bool
|
||||
IsVersion bool
|
||||
DefaultIsEmpty bool
|
||||
EnumOptions map[string]int
|
||||
SetOptions map[string]int
|
||||
DisableTimeZone bool
|
||||
TimeZone *time.Location // column specified time zone
|
||||
AssociateType
|
||||
AssociateTable *Table
|
||||
}
|
||||
|
||||
func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column {
|
||||
return &Column{
|
||||
Name: name,
|
||||
TableName: "",
|
||||
FieldName: fieldName,
|
||||
SQLType: sqlType,
|
||||
Length: len1,
|
||||
Length2: len2,
|
||||
Nullable: nullable,
|
||||
Default: "",
|
||||
Indexes: make(map[string]int),
|
||||
IsPrimaryKey: false,
|
||||
IsAutoIncrement: false,
|
||||
MapType: TWOSIDES,
|
||||
IsCreated: false,
|
||||
IsUpdated: false,
|
||||
IsDeleted: false,
|
||||
IsCascade: false,
|
||||
IsVersion: false,
|
||||
DefaultIsEmpty: false,
|
||||
EnumOptions: make(map[string]int),
|
||||
AssociateType: AssociateNone,
|
||||
AssociateTable: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// generate column description string according dialect
|
||||
func (col *Column) String(d Dialect) string {
|
||||
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
|
||||
|
||||
sql += d.SqlType(col) + " "
|
||||
|
||||
if col.IsPrimaryKey {
|
||||
sql += "PRIMARY KEY "
|
||||
if col.IsAutoIncrement {
|
||||
sql += d.AutoIncrStr() + " "
|
||||
}
|
||||
}
|
||||
|
||||
if d.ShowCreateNull() {
|
||||
if col.Nullable {
|
||||
sql += "NULL "
|
||||
} else {
|
||||
sql += "NOT NULL "
|
||||
}
|
||||
}
|
||||
|
||||
if col.Default != "" {
|
||||
sql += "DEFAULT " + col.Default + " "
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func (col *Column) StringNoPk(d Dialect) string {
|
||||
sql := d.QuoteStr() + col.Name + d.QuoteStr() + " "
|
||||
|
||||
sql += d.SqlType(col) + " "
|
||||
|
||||
if d.ShowCreateNull() {
|
||||
if col.Nullable {
|
||||
sql += "NULL "
|
||||
} else {
|
||||
sql += "NOT NULL "
|
||||
}
|
||||
}
|
||||
|
||||
if col.Default != "" {
|
||||
sql += "DEFAULT " + col.Default + " "
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
// return col's filed of struct's value
|
||||
func (col *Column) ValueOf(bean interface{}) (*reflect.Value, error) {
|
||||
dataStruct := reflect.Indirect(reflect.ValueOf(bean))
|
||||
return col.ValueOfV(&dataStruct)
|
||||
}
|
||||
|
||||
func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) {
|
||||
var fieldValue reflect.Value
|
||||
fieldPath := strings.Split(col.FieldName, ".")
|
||||
|
||||
if dataStruct.Type().Kind() == reflect.Map {
|
||||
keyValue := reflect.ValueOf(fieldPath[len(fieldPath)-1])
|
||||
fieldValue = dataStruct.MapIndex(keyValue)
|
||||
return &fieldValue, nil
|
||||
} else if dataStruct.Type().Kind() == reflect.Interface {
|
||||
structValue := reflect.ValueOf(dataStruct.Interface())
|
||||
dataStruct = &structValue
|
||||
}
|
||||
|
||||
level := len(fieldPath)
|
||||
fieldValue = dataStruct.FieldByName(fieldPath[0])
|
||||
for i := 0; i < level-1; i++ {
|
||||
if !fieldValue.IsValid() {
|
||||
break
|
||||
}
|
||||
if fieldValue.Kind() == reflect.Struct {
|
||||
fieldValue = fieldValue.FieldByName(fieldPath[i+1])
|
||||
} else if fieldValue.Kind() == reflect.Ptr {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldValue.Type().Elem()))
|
||||
}
|
||||
fieldValue = fieldValue.Elem().FieldByName(fieldPath[i+1])
|
||||
} else {
|
||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if !fieldValue.IsValid() {
|
||||
return nil, fmt.Errorf("field %v is not valid", col.FieldName)
|
||||
}
|
||||
|
||||
return &fieldValue, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
package core
|
||||
|
||||
// Conversion is an interface. A type implements Conversion will according
|
||||
// the custom method to fill into database and retrieve from database.
|
||||
type Conversion interface {
|
||||
FromDB([]byte) error
|
||||
ToDB() ([]byte, error)
|
||||
}
|
||||
|
|
@ -0,0 +1,368 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
func MapToSlice(query string, mp interface{}) (string, []interface{}, error) {
|
||||
vv := reflect.ValueOf(mp)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return "", []interface{}{}, ErrNoMapPointer
|
||||
}
|
||||
|
||||
args := make([]interface{}, 0, len(vv.Elem().MapKeys()))
|
||||
var err error
|
||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||
v := vv.Elem().MapIndex(reflect.ValueOf(src[1:]))
|
||||
if !v.IsValid() {
|
||||
err = fmt.Errorf("map key %s is missing", src[1:])
|
||||
} else {
|
||||
args = append(args, v.Interface())
|
||||
}
|
||||
return "?"
|
||||
})
|
||||
|
||||
return query, args, err
|
||||
}
|
||||
|
||||
func StructToSlice(query string, st interface{}) (string, []interface{}, error) {
|
||||
vv := reflect.ValueOf(st)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return "", []interface{}{}, ErrNoStructPointer
|
||||
}
|
||||
|
||||
args := make([]interface{}, 0)
|
||||
var err error
|
||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||
fv := vv.Elem().FieldByName(src[1:]).Interface()
|
||||
if v, ok := fv.(driver.Valuer); ok {
|
||||
var value driver.Value
|
||||
value, err = v.Value()
|
||||
if err != nil {
|
||||
return "?"
|
||||
}
|
||||
args = append(args, value)
|
||||
} else {
|
||||
args = append(args, fv)
|
||||
}
|
||||
return "?"
|
||||
})
|
||||
if err != nil {
|
||||
return "", []interface{}{}, err
|
||||
}
|
||||
return query, args, nil
|
||||
}
|
||||
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
Mapper IMapper
|
||||
}
|
||||
|
||||
func Open(driverName, dataSourceName string) (*DB, error) {
|
||||
db, err := sql.Open(driverName, dataSourceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil
|
||||
}
|
||||
|
||||
func FromDB(db *sql.DB) *DB {
|
||||
return &DB{db, NewCacheMapper(&SnakeMapper{})}
|
||||
}
|
||||
|
||||
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
|
||||
rows, err := db.DB.Query(query, args...)
|
||||
if err != nil {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{rows, db.Mapper}, nil
|
||||
}
|
||||
|
||||
func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.Query(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.Query(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return &Row{rows, nil}
|
||||
}
|
||||
|
||||
func (db *DB) QueryRowMap(query string, mp interface{}) *Row {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return db.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) QueryRowStruct(query string, st interface{}) *Row {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return db.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Mapper IMapper
|
||||
names map[string]int
|
||||
}
|
||||
|
||||
func (db *DB) Prepare(query string) (*Stmt, error) {
|
||||
names := make(map[string]int)
|
||||
var i int
|
||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||
names[src[1:]] = i
|
||||
i += 1
|
||||
return "?"
|
||||
})
|
||||
|
||||
stmt, err := db.DB.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Stmt{stmt, db.Mapper, names}, nil
|
||||
}
|
||||
|
||||
func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
|
||||
vv := reflect.ValueOf(mp)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return nil, errors.New("mp should be a map's pointer")
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
for k, i := range s.names {
|
||||
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
|
||||
}
|
||||
return s.Stmt.Exec(args...)
|
||||
}
|
||||
|
||||
func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
|
||||
vv := reflect.ValueOf(st)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return nil, errors.New("mp should be a map's pointer")
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
for k, i := range s.names {
|
||||
args[i] = vv.Elem().FieldByName(k).Interface()
|
||||
}
|
||||
return s.Stmt.Exec(args...)
|
||||
}
|
||||
|
||||
func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
|
||||
rows, err := s.Stmt.Query(args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{rows, s.Mapper}, nil
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
|
||||
vv := reflect.ValueOf(mp)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return nil, errors.New("mp should be a map's pointer")
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
for k, i := range s.names {
|
||||
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
|
||||
}
|
||||
|
||||
return s.Query(args...)
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
|
||||
vv := reflect.ValueOf(st)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return nil, errors.New("mp should be a map's pointer")
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
for k, i := range s.names {
|
||||
args[i] = vv.Elem().FieldByName(k).Interface()
|
||||
}
|
||||
|
||||
return s.Query(args...)
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryRow(args ...interface{}) *Row {
|
||||
rows, err := s.Query(args...)
|
||||
return &Row{rows, err}
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryRowMap(mp interface{}) *Row {
|
||||
vv := reflect.ValueOf(mp)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return &Row{nil, errors.New("mp should be a map's pointer")}
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
for k, i := range s.names {
|
||||
args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
|
||||
}
|
||||
|
||||
return s.QueryRow(args...)
|
||||
}
|
||||
|
||||
func (s *Stmt) QueryRowStruct(st interface{}) *Row {
|
||||
vv := reflect.ValueOf(st)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return &Row{nil, errors.New("st should be a struct's pointer")}
|
||||
}
|
||||
|
||||
args := make([]interface{}, len(s.names))
|
||||
for k, i := range s.names {
|
||||
args[i] = vv.Elem().FieldByName(k).Interface()
|
||||
}
|
||||
|
||||
return s.QueryRow(args...)
|
||||
}
|
||||
|
||||
var (
|
||||
re = regexp.MustCompile(`[?](\w+)`)
|
||||
)
|
||||
|
||||
// insert into (name) values (?)
|
||||
// insert into (name) values (?name)
|
||||
func (db *DB) ExecMap(query string, mp interface{}) (sql.Result, error) {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.DB.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (db *DB) ExecStruct(query string, st interface{}) (sql.Result, error) {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.DB.Exec(query, args...)
|
||||
}
|
||||
|
||||
type EmptyScanner struct {
|
||||
}
|
||||
|
||||
func (EmptyScanner) Scan(src interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type Tx struct {
|
||||
*sql.Tx
|
||||
Mapper IMapper
|
||||
}
|
||||
|
||||
func (db *DB) Begin() (*Tx, error) {
|
||||
tx, err := db.DB.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{tx, db.Mapper}, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) Prepare(query string) (*Stmt, error) {
|
||||
names := make(map[string]int)
|
||||
var i int
|
||||
query = re.ReplaceAllStringFunc(query, func(src string) string {
|
||||
names[src[1:]] = i
|
||||
i += 1
|
||||
return "?"
|
||||
})
|
||||
|
||||
stmt, err := tx.Tx.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Stmt{stmt, tx.Mapper, names}, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
|
||||
// TODO:
|
||||
return stmt
|
||||
}
|
||||
|
||||
func (tx *Tx) ExecMap(query string, mp interface{}) (sql.Result, error) {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Tx.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) ExecStruct(query string, st interface{}) (sql.Result, error) {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Tx.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
|
||||
rows, err := tx.Tx.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{rows, tx.Mapper}, nil
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryMap(query string, mp interface{}) (*Rows, error) {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Query(query, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tx.Query(query, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
|
||||
rows, err := tx.Query(query, args...)
|
||||
return &Row{rows, err}
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row {
|
||||
query, args, err := MapToSlice(query, mp)
|
||||
if err != nil {
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return tx.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row {
|
||||
query, args, err := StructToSlice(query, st)
|
||||
if err != nil {
|
||||
return &Row{nil, err}
|
||||
}
|
||||
return tx.QueryRow(query, args...)
|
||||
}
|
||||
|
|
@ -0,0 +1,659 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
var (
|
||||
//dbtype string = "sqlite3"
|
||||
dbtype string = "mysql"
|
||||
createTableSql string
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id int64
|
||||
Name string
|
||||
Title string
|
||||
Age float32
|
||||
Alias string
|
||||
NickName string
|
||||
Created NullTime
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch dbtype {
|
||||
case "sqlite3":
|
||||
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `name` TEXT NULL, " +
|
||||
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
|
||||
case "mysql":
|
||||
createTableSql = "CREATE TABLE IF NOT EXISTS `user` (`id` INTEGER PRIMARY KEY AUTO_INCREMENT NOT NULL, `name` TEXT NULL, " +
|
||||
"`title` TEXT NULL, `age` FLOAT NULL, `alias` TEXT NULL, `nick_name` TEXT NULL, `created` datetime);"
|
||||
default:
|
||||
panic("no db type")
|
||||
}
|
||||
}
|
||||
|
||||
func testOpen() (*DB, error) {
|
||||
switch dbtype {
|
||||
case "sqlite3":
|
||||
os.Remove("./test.db")
|
||||
return Open("sqlite3", "./test.db")
|
||||
case "mysql":
|
||||
return Open("mysql", "root:@/core_test?charset=utf8")
|
||||
default:
|
||||
panic("no db type")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkOriQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var Id int64
|
||||
var Name, Title, Alias, NickName string
|
||||
var Age float32
|
||||
var Created NullTime
|
||||
err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
//fmt.Println(Id, Name, Title, Age, Alias, NickName)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStructQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?, ?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var user User
|
||||
err = rows.ScanStructByIndex(&user)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if user.Name != "xlw" {
|
||||
fmt.Println(user)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStruct2Query(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name, created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
db.Mapper = NewCacheMapper(&SnakeMapper{})
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var user User
|
||||
err = rows.ScanStructByName(&user)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if user.Name != "xlw" {
|
||||
fmt.Println(user)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSliceInterfaceQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
slice := make([]interface{}, len(cols))
|
||||
err = rows.ScanSlice(&slice)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
fmt.Println(slice)
|
||||
if *slice[1].(*string) != "xlw" {
|
||||
fmt.Println(slice)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
/*func BenchmarkSliceBytesQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
os.Remove("./test.db")
|
||||
db, err := Open("sqlite3", "./test.db")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
slice := make([][]byte, len(cols))
|
||||
err = rows.ScanSlice(&slice)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if string(slice[1]) != "xlw" {
|
||||
fmt.Println(slice)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
func BenchmarkSliceStringQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (name, title, age, alias, nick_name, created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
slice := make([]*string, len(cols))
|
||||
err = rows.ScanSlice(&slice)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if (*slice[1]) != "xlw" {
|
||||
fmt.Println(slice)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMapInterfaceQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
m := make(map[string]interface{})
|
||||
err = rows.ScanMap(&m)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if m["name"].(string) != "xlw" {
|
||||
fmt.Println(m)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
/*func BenchmarkMapBytesQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
os.Remove("./test.db")
|
||||
db, err := Open("sqlite3", "./test.db")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
m := make(map[string][]byte)
|
||||
err = rows.ScanMap(&m)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if string(m["name"]) != "xlw" {
|
||||
fmt.Println(m)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
*/
|
||||
/*
|
||||
func BenchmarkMapStringQuery(b *testing.B) {
|
||||
b.StopTimer()
|
||||
os.Remove("./test.db")
|
||||
db, err := Open("sqlite3", "./test.db")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, err = db.Exec("insert into user (name, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
m := make(map[string]string)
|
||||
err = rows.ScanMap(&m)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
if m["name"] != "xlw" {
|
||||
fmt.Println(m)
|
||||
b.Error(errors.New("name should be xlw"))
|
||||
}
|
||||
}
|
||||
|
||||
rows.Close()
|
||||
}
|
||||
}*/
|
||||
|
||||
func BenchmarkExec(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = db.Exec("insert into user (`name`, title, age, alias, nick_name,created) values (?,?,?,?,?,?)",
|
||||
"xlw", "tester", 1.2, "lunny", "lunny xiao", time.Now())
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecMap(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
mp := map[string]interface{}{
|
||||
"name": "xlw",
|
||||
"title": "tester",
|
||||
"age": 1.2,
|
||||
"alias": "lunny",
|
||||
"nick_name": "lunny xiao",
|
||||
"created": time.Now(),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name, created) "+
|
||||
"values (?name,?title,?age,?alias,?nick_name,?created)",
|
||||
&mp)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecMap(t *testing.T) {
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
mp := map[string]interface{}{
|
||||
"name": "xlw",
|
||||
"title": "tester",
|
||||
"age": 1.2,
|
||||
"alias": "lunny",
|
||||
"nick_name": "lunny xiao",
|
||||
"created": time.Now(),
|
||||
}
|
||||
|
||||
_, err = db.ExecMap("insert into user (`name`, title, age, alias, nick_name,created) "+
|
||||
"values (?name,?title,?age,?alias,?nick_name,?created)",
|
||||
&mp)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rows, err := db.Query("select * from user")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var user User
|
||||
err = rows.ScanStructByName(&user)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println("--", user)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecStruct(t *testing.T) {
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
user := User{Name: "xlw",
|
||||
Title: "tester",
|
||||
Age: 1.2,
|
||||
Alias: "lunny",
|
||||
NickName: "lunny xiao",
|
||||
Created: NullTime(time.Now()),
|
||||
}
|
||||
|
||||
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
|
||||
"values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
|
||||
&user)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rows, err := db.QueryStruct("select * from user where `name` = ?Name", &user)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var user User
|
||||
err = rows.ScanStructByName(&user)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println("1--", user)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExecStruct(b *testing.B) {
|
||||
b.StopTimer()
|
||||
db, err := testOpen()
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec(createTableSql)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
|
||||
user := User{Name: "xlw",
|
||||
Title: "tester",
|
||||
Age: 1.2,
|
||||
Alias: "lunny",
|
||||
NickName: "lunny xiao",
|
||||
Created: NullTime(time.Now()),
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+
|
||||
"values (?Name,?Title,?Age,?Alias,?NickName,?Created)",
|
||||
&user)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,307 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DbType string
|
||||
|
||||
type Uri struct {
|
||||
DbType DbType
|
||||
Proto string
|
||||
Host string
|
||||
Port string
|
||||
DbName string
|
||||
User string
|
||||
Passwd string
|
||||
Charset string
|
||||
Laddr string
|
||||
Raddr string
|
||||
Timeout time.Duration
|
||||
Schema string
|
||||
}
|
||||
|
||||
// a dialect is a driver's wrapper
|
||||
type Dialect interface {
|
||||
SetLogger(logger ILogger)
|
||||
Init(*DB, *Uri, string, string) error
|
||||
URI() *Uri
|
||||
DB() *DB
|
||||
DBType() DbType
|
||||
SqlType(*Column) string
|
||||
FormatBytes(b []byte) string
|
||||
|
||||
DriverName() string
|
||||
DataSourceName() string
|
||||
|
||||
QuoteStr() string
|
||||
IsReserved(string) bool
|
||||
Quote(string) string
|
||||
AndStr() string
|
||||
OrStr() string
|
||||
EqStr() string
|
||||
RollBackStr() string
|
||||
AutoIncrStr() string
|
||||
|
||||
SupportInsertMany() bool
|
||||
SupportEngine() bool
|
||||
SupportCharset() bool
|
||||
SupportDropIfExists() bool
|
||||
IndexOnTable() bool
|
||||
ShowCreateNull() bool
|
||||
|
||||
IndexCheckSql(tableName, idxName string) (string, []interface{})
|
||||
TableCheckSql(tableName string) (string, []interface{})
|
||||
|
||||
IsColumnExist(tableName string, colName string) (bool, error)
|
||||
|
||||
CreateTableSql(table *Table, tableName, storeEngine, charset string) string
|
||||
DropTableSql(tableName string) string
|
||||
CreateIndexSql(tableName string, index *Index) string
|
||||
DropIndexSql(tableName string, index *Index) string
|
||||
|
||||
ModifyColumnSql(tableName string, col *Column) string
|
||||
|
||||
ForUpdateSql(query string) string
|
||||
|
||||
//CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
|
||||
//MustDropTable(tableName string) error
|
||||
|
||||
GetColumns(tableName string) ([]string, map[string]*Column, error)
|
||||
GetTables() ([]*Table, error)
|
||||
GetIndexes(tableName string) (map[string]*Index, error)
|
||||
|
||||
Filters() []Filter
|
||||
}
|
||||
|
||||
func OpenDialect(dialect Dialect) (*DB, error) {
|
||||
return Open(dialect.DriverName(), dialect.DataSourceName())
|
||||
}
|
||||
|
||||
type Base struct {
|
||||
db *DB
|
||||
dialect Dialect
|
||||
driverName string
|
||||
dataSourceName string
|
||||
logger ILogger
|
||||
*Uri
|
||||
}
|
||||
|
||||
func (b *Base) DB() *DB {
|
||||
return b.db
|
||||
}
|
||||
|
||||
func (b *Base) SetLogger(logger ILogger) {
|
||||
b.logger = logger
|
||||
}
|
||||
|
||||
func (b *Base) Init(db *DB, dialect Dialect, uri *Uri, drivername, dataSourceName string) error {
|
||||
b.db, b.dialect, b.Uri = db, dialect, uri
|
||||
b.driverName, b.dataSourceName = drivername, dataSourceName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Base) URI() *Uri {
|
||||
return b.Uri
|
||||
}
|
||||
|
||||
func (b *Base) DBType() DbType {
|
||||
return b.Uri.DbType
|
||||
}
|
||||
|
||||
func (b *Base) FormatBytes(bs []byte) string {
|
||||
return fmt.Sprintf("0x%x", bs)
|
||||
}
|
||||
|
||||
func (b *Base) DriverName() string {
|
||||
return b.driverName
|
||||
}
|
||||
|
||||
func (b *Base) ShowCreateNull() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *Base) DataSourceName() string {
|
||||
return b.dataSourceName
|
||||
}
|
||||
|
||||
func (b *Base) AndStr() string {
|
||||
return "AND"
|
||||
}
|
||||
|
||||
func (b *Base) OrStr() string {
|
||||
return "OR"
|
||||
}
|
||||
|
||||
func (b *Base) EqStr() string {
|
||||
return "="
|
||||
}
|
||||
|
||||
func (db *Base) RollBackStr() string {
|
||||
return "ROLL BACK"
|
||||
}
|
||||
|
||||
func (db *Base) SupportDropIfExists() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (db *Base) DropTableSql(tableName string) string {
|
||||
return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName)
|
||||
}
|
||||
|
||||
func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
|
||||
db.LogSQL(query, args)
|
||||
rows, err := db.DB().Query(query, args...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (db *Base) IsColumnExist(tableName, colName string) (bool, error) {
|
||||
query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
|
||||
query = strings.Replace(query, "`", db.dialect.QuoteStr(), -1)
|
||||
return db.HasRecords(query, db.DbName, tableName, colName)
|
||||
}
|
||||
|
||||
/*
|
||||
func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
|
||||
sql, args := db.dialect.TableCheckSql(tableName)
|
||||
rows, err := db.DB().Query(sql, args...)
|
||||
if db.Logger != nil {
|
||||
db.Logger.Info("[sql]", sql, args)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
return nil
|
||||
}
|
||||
|
||||
sql = db.dialect.CreateTableSql(table, tableName, storeEngine, charset)
|
||||
_, err = db.DB().Exec(sql)
|
||||
if db.Logger != nil {
|
||||
db.Logger.Info("[sql]", sql)
|
||||
}
|
||||
return err
|
||||
}*/
|
||||
|
||||
func (db *Base) CreateIndexSql(tableName string, index *Index) string {
|
||||
quote := db.dialect.Quote
|
||||
var unique string
|
||||
var idxName string
|
||||
if index.Type == UniqueType {
|
||||
unique = " UNIQUE"
|
||||
}
|
||||
idxName = index.XName(tableName)
|
||||
return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
|
||||
quote(idxName), quote(tableName),
|
||||
quote(strings.Join(index.Cols, quote(","))))
|
||||
}
|
||||
|
||||
func (db *Base) DropIndexSql(tableName string, index *Index) string {
|
||||
quote := db.dialect.Quote
|
||||
var name string
|
||||
if index.IsRegular {
|
||||
name = index.XName(tableName)
|
||||
} else {
|
||||
name = index.Name
|
||||
}
|
||||
return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
|
||||
}
|
||||
|
||||
func (db *Base) ModifyColumnSql(tableName string, col *Column) string {
|
||||
return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, col.StringNoPk(db.dialect))
|
||||
}
|
||||
|
||||
func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset string) string {
|
||||
var sql string
|
||||
sql = "CREATE TABLE IF NOT EXISTS "
|
||||
if tableName == "" {
|
||||
tableName = table.Name
|
||||
}
|
||||
|
||||
sql += b.dialect.Quote(tableName)
|
||||
sql += " ("
|
||||
|
||||
if len(table.ColumnsSeq()) > 0 {
|
||||
pkList := table.PrimaryKeys
|
||||
|
||||
for _, colName := range table.ColumnsSeq() {
|
||||
col := table.GetColumn(colName)
|
||||
if col.IsPrimaryKey && len(pkList) == 1 {
|
||||
sql += col.String(b.dialect)
|
||||
} else {
|
||||
sql += col.StringNoPk(b.dialect)
|
||||
}
|
||||
sql = strings.TrimSpace(sql)
|
||||
sql += ", "
|
||||
}
|
||||
|
||||
if len(pkList) > 1 {
|
||||
sql += "PRIMARY KEY ( "
|
||||
sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
|
||||
sql += " ), "
|
||||
}
|
||||
|
||||
sql = sql[:len(sql)-2]
|
||||
}
|
||||
sql += ")"
|
||||
|
||||
if b.dialect.SupportEngine() && storeEngine != "" {
|
||||
sql += " ENGINE=" + storeEngine
|
||||
}
|
||||
if b.dialect.SupportCharset() {
|
||||
if len(charset) == 0 {
|
||||
charset = b.dialect.URI().Charset
|
||||
}
|
||||
if len(charset) > 0 {
|
||||
sql += " DEFAULT CHARSET " + charset
|
||||
}
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
func (b *Base) ForUpdateSql(query string) string {
|
||||
return query + " FOR UPDATE"
|
||||
}
|
||||
|
||||
func (b *Base) LogSQL(sql string, args []interface{}) {
|
||||
if b.logger != nil && b.logger.IsShowSQL() {
|
||||
if len(args) > 0 {
|
||||
b.logger.Infof("[SQL] %v %v", sql, args)
|
||||
} else {
|
||||
b.logger.Infof("[SQL] %v", sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
dialects = map[string]func() Dialect{}
|
||||
)
|
||||
|
||||
// RegisterDialect register database dialect
|
||||
func RegisterDialect(dbName DbType, dialectFunc func() Dialect) {
|
||||
if dialectFunc == nil {
|
||||
panic("core: Register dialect is nil")
|
||||
}
|
||||
dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
|
||||
}
|
||||
|
||||
// QueryDialect query if registed database dialect
|
||||
func QueryDialect(dbName DbType) Dialect {
|
||||
if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
|
||||
return d()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package core
|
||||
|
||||
type Driver interface {
|
||||
Parse(string, string) (*Uri, error)
|
||||
}
|
||||
|
||||
var (
|
||||
drivers = map[string]Driver{}
|
||||
)
|
||||
|
||||
func RegisterDriver(driverName string, driver Driver) {
|
||||
if driver == nil {
|
||||
panic("core: Register driver is nil")
|
||||
}
|
||||
if _, dup := drivers[driverName]; dup {
|
||||
panic("core: Register called twice for driver " + driverName)
|
||||
}
|
||||
drivers[driverName] = driver
|
||||
}
|
||||
|
||||
func QueryDriver(driverName string) Driver {
|
||||
return drivers[driverName]
|
||||
}
|
||||
|
||||
func RegisteredDriverSize() int {
|
||||
return len(drivers)
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
package core
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoMapPointer = errors.New("mp should be a map's pointer")
|
||||
ErrNoStructPointer = errors.New("mp should be a struct's pointer")
|
||||
)
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Filter is an interface to filter SQL
|
||||
type Filter interface {
|
||||
Do(sql string, dialect Dialect, table *Table) string
|
||||
}
|
||||
|
||||
// QuoteFilter filter SQL replace ` to database's own quote character
|
||||
type QuoteFilter struct {
|
||||
}
|
||||
|
||||
func (s *QuoteFilter) Do(sql string, dialect Dialect, table *Table) string {
|
||||
return strings.Replace(sql, "`", dialect.QuoteStr(), -1)
|
||||
}
|
||||
|
||||
// IdFilter filter SQL replace (id) to primary key column name
|
||||
type IdFilter struct {
|
||||
}
|
||||
|
||||
type Quoter struct {
|
||||
dialect Dialect
|
||||
}
|
||||
|
||||
func NewQuoter(dialect Dialect) *Quoter {
|
||||
return &Quoter{dialect}
|
||||
}
|
||||
|
||||
func (q *Quoter) Quote(content string) string {
|
||||
return q.dialect.QuoteStr() + content + q.dialect.QuoteStr()
|
||||
}
|
||||
|
||||
func (i *IdFilter) Do(sql string, dialect Dialect, table *Table) string {
|
||||
quoter := NewQuoter(dialect)
|
||||
if table != nil && len(table.PrimaryKeys) == 1 {
|
||||
sql = strings.Replace(sql, "`(id)`", quoter.Quote(table.PrimaryKeys[0]), -1)
|
||||
sql = strings.Replace(sql, quoter.Quote("(id)"), quoter.Quote(table.PrimaryKeys[0]), -1)
|
||||
return strings.Replace(sql, "(id)", quoter.Quote(table.PrimaryKeys[0]), -1)
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
// SeqFilter filter SQL replace ?, ? ... to $1, $2 ...
|
||||
type SeqFilter struct {
|
||||
Prefix string
|
||||
Start int
|
||||
}
|
||||
|
||||
func (s *SeqFilter) Do(sql string, dialect Dialect, table *Table) string {
|
||||
segs := strings.Split(sql, "?")
|
||||
size := len(segs)
|
||||
res := ""
|
||||
for i, c := range segs {
|
||||
if i < size-1 {
|
||||
res += c + fmt.Sprintf("%s%v", s.Prefix, i+s.Start)
|
||||
}
|
||||
}
|
||||
res += segs[size-1]
|
||||
return res
|
||||
}
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
package core
|
||||
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
// !nashtsai! following level also match syslog.Priority value
|
||||
LOG_DEBUG LogLevel = iota
|
||||
LOG_INFO
|
||||
LOG_WARNING
|
||||
LOG_ERR
|
||||
LOG_OFF
|
||||
LOG_UNKNOWN
|
||||
)
|
||||
|
||||
// logger interface
|
||||
type ILogger interface {
|
||||
Debug(v ...interface{})
|
||||
Debugf(format string, v ...interface{})
|
||||
Error(v ...interface{})
|
||||
Errorf(format string, v ...interface{})
|
||||
Info(v ...interface{})
|
||||
Infof(format string, v ...interface{})
|
||||
Warn(v ...interface{})
|
||||
Warnf(format string, v ...interface{})
|
||||
|
||||
Level() LogLevel
|
||||
SetLevel(l LogLevel)
|
||||
|
||||
ShowSQL(show ...bool)
|
||||
IsShowSQL() bool
|
||||
}
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
IndexType = iota + 1
|
||||
UniqueType
|
||||
)
|
||||
|
||||
// database index
|
||||
type Index struct {
|
||||
IsRegular bool
|
||||
Name string
|
||||
Type int
|
||||
Cols []string
|
||||
}
|
||||
|
||||
func (index *Index) XName(tableName string) string {
|
||||
if !strings.HasPrefix(index.Name, "UQE_") &&
|
||||
!strings.HasPrefix(index.Name, "IDX_") {
|
||||
if index.Type == UniqueType {
|
||||
return fmt.Sprintf("UQE_%v_%v", tableName, index.Name)
|
||||
}
|
||||
return fmt.Sprintf("IDX_%v_%v", tableName, index.Name)
|
||||
}
|
||||
return index.Name
|
||||
}
|
||||
|
||||
// add columns which will be composite index
|
||||
func (index *Index) AddColumn(cols ...string) {
|
||||
for _, col := range cols {
|
||||
index.Cols = append(index.Cols, col)
|
||||
}
|
||||
}
|
||||
|
||||
func (index *Index) Equal(dst *Index) bool {
|
||||
if index.Type != dst.Type {
|
||||
return false
|
||||
}
|
||||
if len(index.Cols) != len(dst.Cols) {
|
||||
return false
|
||||
}
|
||||
sort.StringSlice(index.Cols).Sort()
|
||||
sort.StringSlice(dst.Cols).Sort()
|
||||
|
||||
for i := 0; i < len(index.Cols); i++ {
|
||||
if index.Cols[i] != dst.Cols[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// new an index
|
||||
func NewIndex(name string, indexType int) *Index {
|
||||
return &Index{true, name, indexType, make([]string, 0)}
|
||||
}
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// name translation between struct, fields names and table, column names
|
||||
type IMapper interface {
|
||||
Obj2Table(string) string
|
||||
Table2Obj(string) string
|
||||
}
|
||||
|
||||
type CacheMapper struct {
|
||||
oriMapper IMapper
|
||||
obj2tableCache map[string]string
|
||||
obj2tableMutex sync.RWMutex
|
||||
table2objCache map[string]string
|
||||
table2objMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewCacheMapper(mapper IMapper) *CacheMapper {
|
||||
return &CacheMapper{oriMapper: mapper, obj2tableCache: make(map[string]string),
|
||||
table2objCache: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *CacheMapper) Obj2Table(o string) string {
|
||||
m.obj2tableMutex.RLock()
|
||||
t, ok := m.obj2tableCache[o]
|
||||
m.obj2tableMutex.RUnlock()
|
||||
if ok {
|
||||
return t
|
||||
}
|
||||
|
||||
t = m.oriMapper.Obj2Table(o)
|
||||
m.obj2tableMutex.Lock()
|
||||
m.obj2tableCache[o] = t
|
||||
m.obj2tableMutex.Unlock()
|
||||
return t
|
||||
}
|
||||
|
||||
func (m *CacheMapper) Table2Obj(t string) string {
|
||||
m.table2objMutex.RLock()
|
||||
o, ok := m.table2objCache[t]
|
||||
m.table2objMutex.RUnlock()
|
||||
if ok {
|
||||
return o
|
||||
}
|
||||
|
||||
o = m.oriMapper.Table2Obj(t)
|
||||
m.table2objMutex.Lock()
|
||||
m.table2objCache[t] = o
|
||||
m.table2objMutex.Unlock()
|
||||
return o
|
||||
}
|
||||
|
||||
// SameMapper implements IMapper and provides same name between struct and
|
||||
// database table
|
||||
type SameMapper struct {
|
||||
}
|
||||
|
||||
func (m SameMapper) Obj2Table(o string) string {
|
||||
return o
|
||||
}
|
||||
|
||||
func (m SameMapper) Table2Obj(t string) string {
|
||||
return t
|
||||
}
|
||||
|
||||
// SnakeMapper implements IMapper and provides name transaltion between
|
||||
// struct and database table
|
||||
type SnakeMapper struct {
|
||||
}
|
||||
|
||||
func snakeCasedName(name string) string {
|
||||
newstr := make([]rune, 0)
|
||||
for idx, chr := range name {
|
||||
if isUpper := 'A' <= chr && chr <= 'Z'; isUpper {
|
||||
if idx > 0 {
|
||||
newstr = append(newstr, '_')
|
||||
}
|
||||
chr -= ('A' - 'a')
|
||||
}
|
||||
newstr = append(newstr, chr)
|
||||
}
|
||||
|
||||
return string(newstr)
|
||||
}
|
||||
|
||||
func (mapper SnakeMapper) Obj2Table(name string) string {
|
||||
return snakeCasedName(name)
|
||||
}
|
||||
|
||||
func titleCasedName(name string) string {
|
||||
newstr := make([]rune, 0)
|
||||
upNextChar := true
|
||||
|
||||
name = strings.ToLower(name)
|
||||
|
||||
for _, chr := range name {
|
||||
switch {
|
||||
case upNextChar:
|
||||
upNextChar = false
|
||||
if 'a' <= chr && chr <= 'z' {
|
||||
chr -= ('a' - 'A')
|
||||
}
|
||||
case chr == '_':
|
||||
upNextChar = true
|
||||
continue
|
||||
}
|
||||
|
||||
newstr = append(newstr, chr)
|
||||
}
|
||||
|
||||
return string(newstr)
|
||||
}
|
||||
|
||||
func (mapper SnakeMapper) Table2Obj(name string) string {
|
||||
return titleCasedName(name)
|
||||
}
|
||||
|
||||
// GonicMapper implements IMapper. It will consider initialisms when mapping names.
|
||||
// E.g. id -> ID, user -> User and to table names: UserID -> user_id, MyUID -> my_uid
|
||||
type GonicMapper map[string]bool
|
||||
|
||||
func isASCIIUpper(r rune) bool {
|
||||
return 'A' <= r && r <= 'Z'
|
||||
}
|
||||
|
||||
func toASCIIUpper(r rune) rune {
|
||||
if 'a' <= r && r <= 'z' {
|
||||
r -= ('a' - 'A')
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func gonicCasedName(name string) string {
|
||||
newstr := make([]rune, 0, len(name)+3)
|
||||
for idx, chr := range name {
|
||||
if isASCIIUpper(chr) && idx > 0 {
|
||||
if !isASCIIUpper(newstr[len(newstr)-1]) {
|
||||
newstr = append(newstr, '_')
|
||||
}
|
||||
}
|
||||
|
||||
if !isASCIIUpper(chr) && idx > 1 {
|
||||
l := len(newstr)
|
||||
if isASCIIUpper(newstr[l-1]) && isASCIIUpper(newstr[l-2]) {
|
||||
newstr = append(newstr, newstr[l-1])
|
||||
newstr[l-1] = '_'
|
||||
}
|
||||
}
|
||||
|
||||
newstr = append(newstr, chr)
|
||||
}
|
||||
return strings.ToLower(string(newstr))
|
||||
}
|
||||
|
||||
func (mapper GonicMapper) Obj2Table(name string) string {
|
||||
return gonicCasedName(name)
|
||||
}
|
||||
|
||||
func (mapper GonicMapper) Table2Obj(name string) string {
|
||||
newstr := make([]rune, 0)
|
||||
|
||||
name = strings.ToLower(name)
|
||||
parts := strings.Split(name, "_")
|
||||
|
||||
for _, p := range parts {
|
||||
_, isInitialism := mapper[strings.ToUpper(p)]
|
||||
for i, r := range p {
|
||||
if i == 0 || isInitialism {
|
||||
r = toASCIIUpper(r)
|
||||
}
|
||||
newstr = append(newstr, r)
|
||||
}
|
||||
}
|
||||
|
||||
return string(newstr)
|
||||
}
|
||||
|
||||
// A GonicMapper that contains a list of common initialisms taken from golang/lint
|
||||
var LintGonicMapper = GonicMapper{
|
||||
"API": true,
|
||||
"ASCII": true,
|
||||
"CPU": true,
|
||||
"CSS": true,
|
||||
"DNS": true,
|
||||
"EOF": true,
|
||||
"GUID": true,
|
||||
"HTML": true,
|
||||
"HTTP": true,
|
||||
"HTTPS": true,
|
||||
"ID": true,
|
||||
"IP": true,
|
||||
"JSON": true,
|
||||
"LHS": true,
|
||||
"QPS": true,
|
||||
"RAM": true,
|
||||
"RHS": true,
|
||||
"RPC": true,
|
||||
"SLA": true,
|
||||
"SMTP": true,
|
||||
"SSH": true,
|
||||
"TLS": true,
|
||||
"TTL": true,
|
||||
"UI": true,
|
||||
"UID": true,
|
||||
"UUID": true,
|
||||
"URI": true,
|
||||
"URL": true,
|
||||
"UTF8": true,
|
||||
"VM": true,
|
||||
"XML": true,
|
||||
"XSRF": true,
|
||||
"XSS": true,
|
||||
}
|
||||
|
||||
// provide prefix table name support
|
||||
type PrefixMapper struct {
|
||||
Mapper IMapper
|
||||
Prefix string
|
||||
}
|
||||
|
||||
func (mapper PrefixMapper) Obj2Table(name string) string {
|
||||
return mapper.Prefix + mapper.Mapper.Obj2Table(name)
|
||||
}
|
||||
|
||||
func (mapper PrefixMapper) Table2Obj(name string) string {
|
||||
return mapper.Mapper.Table2Obj(name[len(mapper.Prefix):])
|
||||
}
|
||||
|
||||
func NewPrefixMapper(mapper IMapper, prefix string) PrefixMapper {
|
||||
return PrefixMapper{mapper, prefix}
|
||||
}
|
||||
|
||||
// provide suffix table name support
|
||||
type SuffixMapper struct {
|
||||
Mapper IMapper
|
||||
Suffix string
|
||||
}
|
||||
|
||||
func (mapper SuffixMapper) Obj2Table(name string) string {
|
||||
return mapper.Mapper.Obj2Table(name) + mapper.Suffix
|
||||
}
|
||||
|
||||
func (mapper SuffixMapper) Table2Obj(name string) string {
|
||||
return mapper.Mapper.Table2Obj(name[:len(name)-len(mapper.Suffix)])
|
||||
}
|
||||
|
||||
func NewSuffixMapper(mapper IMapper, suffix string) SuffixMapper {
|
||||
return SuffixMapper{mapper, suffix}
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGonicMapperFromObj(t *testing.T) {
|
||||
testCases := map[string]string{
|
||||
"HTTPLib": "http_lib",
|
||||
"id": "id",
|
||||
"ID": "id",
|
||||
"IDa": "i_da",
|
||||
"iDa": "i_da",
|
||||
"IDAa": "id_aa",
|
||||
"aID": "a_id",
|
||||
"aaID": "aa_id",
|
||||
"aaaID": "aaa_id",
|
||||
"MyREalFunkYLONgNAME": "my_r_eal_funk_ylo_ng_name",
|
||||
}
|
||||
|
||||
for in, expected := range testCases {
|
||||
out := gonicCasedName(in)
|
||||
if out != expected {
|
||||
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGonicMapperToObj(t *testing.T) {
|
||||
testCases := map[string]string{
|
||||
"http_lib": "HTTPLib",
|
||||
"id": "ID",
|
||||
"ida": "Ida",
|
||||
"id_aa": "IDAa",
|
||||
"aa_id": "AaID",
|
||||
"my_r_eal_funk_ylo_ng_name": "MyREalFunkYloNgName",
|
||||
}
|
||||
|
||||
for in, expected := range testCases {
|
||||
out := LintGonicMapper.Table2Obj(in)
|
||||
if out != expected {
|
||||
t.Errorf("Given %s, expected %s but got %s", in, expected, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
)
|
||||
|
||||
type PK []interface{}
|
||||
|
||||
func NewPK(pks ...interface{}) *PK {
|
||||
p := PK(pks)
|
||||
return &p
|
||||
}
|
||||
|
||||
func (p *PK) ToString() (string, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(*p)
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
func (p *PK) FromString(content string) error {
|
||||
dec := gob.NewDecoder(bytes.NewBufferString(content))
|
||||
err := dec.Decode(p)
|
||||
return err
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPK(t *testing.T) {
|
||||
p := NewPK(1, 3, "string")
|
||||
str, err := p.ToString()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(str)
|
||||
|
||||
s := &PK{}
|
||||
err = s.FromString(str)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(s)
|
||||
|
||||
if len(*p) != len(*s) {
|
||||
t.Fatal("p", *p, "should be equal", *s)
|
||||
}
|
||||
|
||||
for i, ori := range *p {
|
||||
if ori != (*s)[i] {
|
||||
t.Fatal("ori", ori, reflect.ValueOf(ori), "should be equal", (*s)[i], reflect.ValueOf((*s)[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,380 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Rows struct {
|
||||
*sql.Rows
|
||||
Mapper IMapper
|
||||
}
|
||||
|
||||
func (rs *Rows) ToMapString() ([]map[string]string, error) {
|
||||
cols, err := rs.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var results = make([]map[string]string, 0, 10)
|
||||
for rs.Next() {
|
||||
var record = make(map[string]string, len(cols))
|
||||
err = rs.ScanMap(&record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, record)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// scan data to a struct's pointer according field index
|
||||
func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
|
||||
if len(dest) == 0 {
|
||||
return errors.New("at least one struct")
|
||||
}
|
||||
|
||||
vvvs := make([]reflect.Value, len(dest))
|
||||
for i, s := range dest {
|
||||
vv := reflect.ValueOf(s)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return errors.New("dest should be a struct's pointer")
|
||||
}
|
||||
|
||||
vvvs[i] = vv.Elem()
|
||||
}
|
||||
|
||||
cols, err := rs.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newDest := make([]interface{}, len(cols))
|
||||
|
||||
var i = 0
|
||||
for _, vvv := range vvvs {
|
||||
for j := 0; j < vvv.NumField(); j++ {
|
||||
newDest[i] = vvv.Field(j).Addr().Interface()
|
||||
i = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
return rs.Rows.Scan(newDest...)
|
||||
}
|
||||
|
||||
var (
|
||||
fieldCache = make(map[reflect.Type]map[string]int)
|
||||
fieldCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
func fieldByName(v reflect.Value, name string) reflect.Value {
|
||||
t := v.Type()
|
||||
fieldCacheMutex.RLock()
|
||||
cache, ok := fieldCache[t]
|
||||
fieldCacheMutex.RUnlock()
|
||||
if !ok {
|
||||
cache = make(map[string]int)
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
cache[t.Field(i).Name] = i
|
||||
}
|
||||
fieldCacheMutex.Lock()
|
||||
fieldCache[t] = cache
|
||||
fieldCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
if i, ok := cache[name]; ok {
|
||||
return v.Field(i)
|
||||
}
|
||||
|
||||
return reflect.Zero(t)
|
||||
}
|
||||
|
||||
// scan data to a struct's pointer according field name
|
||||
func (rs *Rows) ScanStructByName(dest interface{}) error {
|
||||
vv := reflect.ValueOf(dest)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
|
||||
return errors.New("dest should be a struct's pointer")
|
||||
}
|
||||
|
||||
cols, err := rs.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newDest := make([]interface{}, len(cols))
|
||||
var v EmptyScanner
|
||||
for j, name := range cols {
|
||||
f := fieldByName(vv.Elem(), rs.Mapper.Table2Obj(name))
|
||||
if f.IsValid() {
|
||||
newDest[j] = f.Addr().Interface()
|
||||
} else {
|
||||
newDest[j] = &v
|
||||
}
|
||||
}
|
||||
|
||||
return rs.Rows.Scan(newDest...)
|
||||
}
|
||||
|
||||
type cacheStruct struct {
|
||||
value reflect.Value
|
||||
idx int
|
||||
}
|
||||
|
||||
var (
|
||||
reflectCache = make(map[reflect.Type]*cacheStruct)
|
||||
reflectCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
func ReflectNew(typ reflect.Type) reflect.Value {
|
||||
reflectCacheMutex.RLock()
|
||||
cs, ok := reflectCache[typ]
|
||||
reflectCacheMutex.RUnlock()
|
||||
|
||||
const newSize = 200
|
||||
|
||||
if !ok || cs.idx+1 > newSize-1 {
|
||||
cs = &cacheStruct{reflect.MakeSlice(reflect.SliceOf(typ), newSize, newSize), 0}
|
||||
reflectCacheMutex.Lock()
|
||||
reflectCache[typ] = cs
|
||||
reflectCacheMutex.Unlock()
|
||||
} else {
|
||||
reflectCacheMutex.Lock()
|
||||
cs.idx = cs.idx + 1
|
||||
reflectCacheMutex.Unlock()
|
||||
}
|
||||
return cs.value.Index(cs.idx).Addr()
|
||||
}
|
||||
|
||||
// scan data to a slice's pointer, slice's length should equal to columns' number
|
||||
func (rs *Rows) ScanSlice(dest interface{}) error {
|
||||
vv := reflect.ValueOf(dest)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
|
||||
return errors.New("dest should be a slice's pointer")
|
||||
}
|
||||
|
||||
vvv := vv.Elem()
|
||||
cols, err := rs.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newDest := make([]interface{}, len(cols))
|
||||
|
||||
for j := 0; j < len(cols); j++ {
|
||||
if j >= vvv.Len() {
|
||||
newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
|
||||
} else {
|
||||
newDest[j] = vvv.Index(j).Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
err = rs.Rows.Scan(newDest...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srcLen := vvv.Len()
|
||||
for i := srcLen; i < len(cols); i++ {
|
||||
vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scan data to a map's pointer
|
||||
func (rs *Rows) ScanMap(dest interface{}) error {
|
||||
vv := reflect.ValueOf(dest)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return errors.New("dest should be a map's pointer")
|
||||
}
|
||||
|
||||
cols, err := rs.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newDest := make([]interface{}, len(cols))
|
||||
vvv := vv.Elem()
|
||||
|
||||
for i, _ := range cols {
|
||||
newDest[i] = ReflectNew(vvv.Type().Elem()).Interface()
|
||||
//v := reflect.New(vvv.Type().Elem())
|
||||
//newDest[i] = v.Interface()
|
||||
}
|
||||
|
||||
err = rs.Rows.Scan(newDest...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, name := range cols {
|
||||
vname := reflect.ValueOf(name)
|
||||
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
/*func (rs *Rows) ScanMap(dest interface{}) error {
|
||||
vv := reflect.ValueOf(dest)
|
||||
if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
|
||||
return errors.New("dest should be a map's pointer")
|
||||
}
|
||||
|
||||
cols, err := rs.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newDest := make([]interface{}, len(cols))
|
||||
err = rs.ScanSlice(newDest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vvv := vv.Elem()
|
||||
|
||||
for i, name := range cols {
|
||||
vname := reflect.ValueOf(name)
|
||||
vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
|
||||
}
|
||||
|
||||
return nil
|
||||
}*/
|
||||
type Row struct {
|
||||
rows *Rows
|
||||
// One of these two will be non-nil:
|
||||
err error // deferred error for easy chaining
|
||||
}
|
||||
|
||||
func (row *Row) Columns() ([]string, error) {
|
||||
if row.err != nil {
|
||||
return nil, row.err
|
||||
}
|
||||
return row.rows.Columns()
|
||||
}
|
||||
|
||||
func (row *Row) Scan(dest ...interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
defer row.rows.Close()
|
||||
|
||||
for _, dp := range dest {
|
||||
if _, ok := dp.(*sql.RawBytes); ok {
|
||||
return errors.New("sql: RawBytes isn't allowed on Row.Scan")
|
||||
}
|
||||
}
|
||||
|
||||
if !row.rows.Next() {
|
||||
if err := row.rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
err := row.rows.Scan(dest...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Make sure the query can be processed to completion with no errors.
|
||||
return row.rows.Close()
|
||||
}
|
||||
|
||||
func (row *Row) ScanStructByName(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
defer row.rows.Close()
|
||||
|
||||
if !row.rows.Next() {
|
||||
if err := row.rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
err := row.rows.ScanStructByName(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Make sure the query can be processed to completion with no errors.
|
||||
return row.rows.Close()
|
||||
}
|
||||
|
||||
func (row *Row) ScanStructByIndex(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
defer row.rows.Close()
|
||||
|
||||
if !row.rows.Next() {
|
||||
if err := row.rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
err := row.rows.ScanStructByIndex(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Make sure the query can be processed to completion with no errors.
|
||||
return row.rows.Close()
|
||||
}
|
||||
|
||||
// scan data to a slice's pointer, slice's length should equal to columns' number
|
||||
func (row *Row) ScanSlice(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
defer row.rows.Close()
|
||||
|
||||
if !row.rows.Next() {
|
||||
if err := row.rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
err := row.rows.ScanSlice(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the query can be processed to completion with no errors.
|
||||
return row.rows.Close()
|
||||
}
|
||||
|
||||
// scan data to a map's pointer
|
||||
func (row *Row) ScanMap(dest interface{}) error {
|
||||
if row.err != nil {
|
||||
return row.err
|
||||
}
|
||||
defer row.rows.Close()
|
||||
|
||||
if !row.rows.Next() {
|
||||
if err := row.rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
err := row.rows.ScanMap(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Make sure the query can be processed to completion with no errors.
|
||||
return row.rows.Close()
|
||||
}
|
||||
|
||||
func (row *Row) ToMapString() (map[string]string, error) {
|
||||
cols, err := row.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var record = make(map[string]string, len(cols))
|
||||
err = row.ScanMap(&record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NullTime time.Time
|
||||
|
||||
var (
|
||||
_ driver.Valuer = NullTime{}
|
||||
)
|
||||
|
||||
func (ns *NullTime) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
return convertTime(ns, value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (ns NullTime) Value() (driver.Value, error) {
|
||||
if (time.Time)(ns).IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil
|
||||
}
|
||||
|
||||
func convertTime(dest *NullTime, src interface{}) error {
|
||||
// Common cases, without reflect.
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
t, err := time.Parse("2006-01-02 15:04:05", s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dest = NullTime(t)
|
||||
return nil
|
||||
case []uint8:
|
||||
t, err := time.Parse("2006-01-02 15:04:05", string(s))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*dest = NullTime(t)
|
||||
return nil
|
||||
case nil:
|
||||
default:
|
||||
return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// database table
|
||||
type Table struct {
|
||||
Name string
|
||||
Type reflect.Type
|
||||
columnsSeq []string
|
||||
columnsMap map[string][]*Column
|
||||
columns []*Column
|
||||
Indexes map[string]*Index
|
||||
PrimaryKeys []string
|
||||
AutoIncrement string
|
||||
Created map[string]bool
|
||||
Updated string
|
||||
Deleted string
|
||||
Version string
|
||||
Cacher Cacher
|
||||
StoreEngine string
|
||||
Charset string
|
||||
}
|
||||
|
||||
func (table *Table) Columns() []*Column {
|
||||
return table.columns
|
||||
}
|
||||
|
||||
func (table *Table) ColumnsSeq() []string {
|
||||
return table.columnsSeq
|
||||
}
|
||||
|
||||
func NewEmptyTable() *Table {
|
||||
return NewTable("", nil)
|
||||
}
|
||||
|
||||
func NewTable(name string, t reflect.Type) *Table {
|
||||
return &Table{Name: name, Type: t,
|
||||
columnsSeq: make([]string, 0),
|
||||
columns: make([]*Column, 0),
|
||||
columnsMap: make(map[string][]*Column),
|
||||
Indexes: make(map[string]*Index),
|
||||
Created: make(map[string]bool),
|
||||
PrimaryKeys: make([]string, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (table *Table) columnsByName(name string) []*Column {
|
||||
|
||||
n := len(name)
|
||||
|
||||
for k := range table.columnsMap {
|
||||
if len(k) != n {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(k, name) {
|
||||
return table.columnsMap[k]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (table *Table) GetColumn(name string) *Column {
|
||||
|
||||
cols := table.columnsByName(name)
|
||||
|
||||
if cols != nil {
|
||||
return cols[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (table *Table) GetColumnIdx(name string, idx int) *Column {
|
||||
|
||||
cols := table.columnsByName(name)
|
||||
|
||||
if cols != nil && idx < len(cols) {
|
||||
return cols[idx]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// if has primary key, return column
|
||||
func (table *Table) PKColumns() []*Column {
|
||||
columns := make([]*Column, len(table.PrimaryKeys))
|
||||
for i, name := range table.PrimaryKeys {
|
||||
columns[i] = table.GetColumn(name)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (table *Table) ColumnType(name string) reflect.Type {
|
||||
t, _ := table.Type.FieldByName(name)
|
||||
return t.Type
|
||||
}
|
||||
|
||||
func (table *Table) AutoIncrColumn() *Column {
|
||||
return table.GetColumn(table.AutoIncrement)
|
||||
}
|
||||
|
||||
func (table *Table) VersionColumn() *Column {
|
||||
return table.GetColumn(table.Version)
|
||||
}
|
||||
|
||||
func (table *Table) UpdatedColumn() *Column {
|
||||
return table.GetColumn(table.Updated)
|
||||
}
|
||||
|
||||
func (table *Table) DeletedColumn() *Column {
|
||||
return table.GetColumn(table.Deleted)
|
||||
}
|
||||
|
||||
// add a column to table
|
||||
func (table *Table) AddColumn(col *Column) {
|
||||
table.columnsSeq = append(table.columnsSeq, col.Name)
|
||||
table.columns = append(table.columns, col)
|
||||
colName := strings.ToLower(col.Name)
|
||||
if c, ok := table.columnsMap[colName]; ok {
|
||||
table.columnsMap[colName] = append(c, col)
|
||||
} else {
|
||||
table.columnsMap[colName] = []*Column{col}
|
||||
}
|
||||
|
||||
if col.IsPrimaryKey {
|
||||
table.PrimaryKeys = append(table.PrimaryKeys, col.Name)
|
||||
}
|
||||
if col.IsAutoIncrement {
|
||||
table.AutoIncrement = col.Name
|
||||
}
|
||||
if col.IsCreated {
|
||||
table.Created[col.Name] = true
|
||||
}
|
||||
if col.IsUpdated {
|
||||
table.Updated = col.Name
|
||||
}
|
||||
if col.IsDeleted {
|
||||
table.Deleted = col.Name
|
||||
}
|
||||
if col.IsVersion {
|
||||
table.Version = col.Name
|
||||
}
|
||||
}
|
||||
|
||||
// add an index or an unique to table
|
||||
func (table *Table) AddIndex(index *Index) {
|
||||
table.Indexes[index.Name] = index
|
||||
}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testsGetColumn = []struct {
|
||||
name string
|
||||
idx int
|
||||
}{
|
||||
{"Id", 0},
|
||||
{"Deleted", 0},
|
||||
{"Caption", 0},
|
||||
{"Code_1", 0},
|
||||
{"Code_2", 0},
|
||||
{"Code_3", 0},
|
||||
{"Parent_Id", 0},
|
||||
{"Latitude", 0},
|
||||
{"Longitude", 0},
|
||||
}
|
||||
|
||||
var table *Table
|
||||
|
||||
func init() {
|
||||
|
||||
table = NewEmptyTable()
|
||||
|
||||
var name string
|
||||
|
||||
for i := 0; i < len(testsGetColumn); i++ {
|
||||
// as in Table.AddColumn func
|
||||
name = strings.ToLower(testsGetColumn[i].name)
|
||||
|
||||
table.columnsMap[name] = append(table.columnsMap[name], &Column{})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumn(t *testing.T) {
|
||||
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumn(test.name) == nil {
|
||||
t.Error("Column not found!")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumnIdx(t *testing.T) {
|
||||
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumnIdx(test.name, test.idx) == nil {
|
||||
t.Errorf("Column %s with idx %d not found!", test.name, test.idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetColumnWithToLower(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
|
||||
if _, ok := table.columnsMap[strings.ToLower(test.name)]; !ok {
|
||||
b.Errorf("Column not found:%s", test.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetColumnIdxWithToLower(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
|
||||
if c, ok := table.columnsMap[strings.ToLower(test.name)]; ok {
|
||||
if test.idx < len(c) {
|
||||
continue
|
||||
} else {
|
||||
b.Errorf("Bad idx in: %s, %d", test.name, test.idx)
|
||||
}
|
||||
} else {
|
||||
b.Errorf("Column not found: %s, %d", test.name, test.idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetColumn(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumn(test.name) == nil {
|
||||
b.Errorf("Column not found:%s", test.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetColumnIdx(b *testing.B) {
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, test := range testsGetColumn {
|
||||
if table.GetColumnIdx(test.name, test.idx) == nil {
|
||||
b.Errorf("Column not found:%s, %d", test.name, test.idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,304 @@
|
|||
package core
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
POSTGRES = "postgres"
|
||||
SQLITE = "sqlite3"
|
||||
MYSQL = "mysql"
|
||||
MSSQL = "mssql"
|
||||
ORACLE = "oracle"
|
||||
)
|
||||
|
||||
// xorm SQL types
|
||||
type SQLType struct {
|
||||
Name string
|
||||
DefaultLength int
|
||||
DefaultLength2 int
|
||||
}
|
||||
|
||||
const (
|
||||
UNKNOW_TYPE = iota
|
||||
TEXT_TYPE
|
||||
BLOB_TYPE
|
||||
TIME_TYPE
|
||||
NUMERIC_TYPE
|
||||
)
|
||||
|
||||
func (s *SQLType) IsType(st int) bool {
|
||||
if t, ok := SqlTypes[s.Name]; ok && t == st {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SQLType) IsText() bool {
|
||||
return s.IsType(TEXT_TYPE)
|
||||
}
|
||||
|
||||
func (s *SQLType) IsBlob() bool {
|
||||
return s.IsType(BLOB_TYPE)
|
||||
}
|
||||
|
||||
func (s *SQLType) IsTime() bool {
|
||||
return s.IsType(TIME_TYPE)
|
||||
}
|
||||
|
||||
func (s *SQLType) IsNumeric() bool {
|
||||
return s.IsType(NUMERIC_TYPE)
|
||||
}
|
||||
|
||||
func (s *SQLType) IsJson() bool {
|
||||
return s.Name == Json || s.Name == Jsonb
|
||||
}
|
||||
|
||||
var (
|
||||
Bit = "BIT"
|
||||
TinyInt = "TINYINT"
|
||||
SmallInt = "SMALLINT"
|
||||
MediumInt = "MEDIUMINT"
|
||||
Int = "INT"
|
||||
Integer = "INTEGER"
|
||||
BigInt = "BIGINT"
|
||||
|
||||
Enum = "ENUM"
|
||||
Set = "SET"
|
||||
|
||||
Char = "CHAR"
|
||||
Varchar = "VARCHAR"
|
||||
NVarchar = "NVARCHAR"
|
||||
TinyText = "TINYTEXT"
|
||||
Text = "TEXT"
|
||||
Clob = "CLOB"
|
||||
MediumText = "MEDIUMTEXT"
|
||||
LongText = "LONGTEXT"
|
||||
Uuid = "UUID"
|
||||
|
||||
Date = "DATE"
|
||||
DateTime = "DATETIME"
|
||||
Time = "TIME"
|
||||
TimeStamp = "TIMESTAMP"
|
||||
TimeStampz = "TIMESTAMPZ"
|
||||
|
||||
Decimal = "DECIMAL"
|
||||
Numeric = "NUMERIC"
|
||||
|
||||
Real = "REAL"
|
||||
Float = "FLOAT"
|
||||
Double = "DOUBLE"
|
||||
|
||||
Binary = "BINARY"
|
||||
VarBinary = "VARBINARY"
|
||||
TinyBlob = "TINYBLOB"
|
||||
Blob = "BLOB"
|
||||
MediumBlob = "MEDIUMBLOB"
|
||||
LongBlob = "LONGBLOB"
|
||||
Bytea = "BYTEA"
|
||||
|
||||
Bool = "BOOL"
|
||||
|
||||
Serial = "SERIAL"
|
||||
BigSerial = "BIGSERIAL"
|
||||
|
||||
Json = "JSON"
|
||||
Jsonb = "JSONB"
|
||||
|
||||
SqlTypes = map[string]int{
|
||||
Bit: NUMERIC_TYPE,
|
||||
TinyInt: NUMERIC_TYPE,
|
||||
SmallInt: NUMERIC_TYPE,
|
||||
MediumInt: NUMERIC_TYPE,
|
||||
Int: NUMERIC_TYPE,
|
||||
Integer: NUMERIC_TYPE,
|
||||
BigInt: NUMERIC_TYPE,
|
||||
|
||||
Enum: TEXT_TYPE,
|
||||
Set: TEXT_TYPE,
|
||||
Json: TEXT_TYPE,
|
||||
Jsonb: TEXT_TYPE,
|
||||
|
||||
Char: TEXT_TYPE,
|
||||
Varchar: TEXT_TYPE,
|
||||
NVarchar: TEXT_TYPE,
|
||||
TinyText: TEXT_TYPE,
|
||||
Text: TEXT_TYPE,
|
||||
MediumText: TEXT_TYPE,
|
||||
LongText: TEXT_TYPE,
|
||||
Uuid: TEXT_TYPE,
|
||||
Clob: TEXT_TYPE,
|
||||
|
||||
Date: TIME_TYPE,
|
||||
DateTime: TIME_TYPE,
|
||||
Time: TIME_TYPE,
|
||||
TimeStamp: TIME_TYPE,
|
||||
TimeStampz: TIME_TYPE,
|
||||
|
||||
Decimal: NUMERIC_TYPE,
|
||||
Numeric: NUMERIC_TYPE,
|
||||
Real: NUMERIC_TYPE,
|
||||
Float: NUMERIC_TYPE,
|
||||
Double: NUMERIC_TYPE,
|
||||
|
||||
Binary: BLOB_TYPE,
|
||||
VarBinary: BLOB_TYPE,
|
||||
|
||||
TinyBlob: BLOB_TYPE,
|
||||
Blob: BLOB_TYPE,
|
||||
MediumBlob: BLOB_TYPE,
|
||||
LongBlob: BLOB_TYPE,
|
||||
Bytea: BLOB_TYPE,
|
||||
|
||||
Bool: NUMERIC_TYPE,
|
||||
|
||||
Serial: NUMERIC_TYPE,
|
||||
BigSerial: NUMERIC_TYPE,
|
||||
}
|
||||
|
||||
intTypes = sort.StringSlice{"*int", "*int16", "*int32", "*int8"}
|
||||
uintTypes = sort.StringSlice{"*uint", "*uint16", "*uint32", "*uint8"}
|
||||
)
|
||||
|
||||
// !nashtsai! treat following var as interal const values, these are used for reflect.TypeOf comparision
|
||||
var (
|
||||
c_EMPTY_STRING string
|
||||
c_BOOL_DEFAULT bool
|
||||
c_BYTE_DEFAULT byte
|
||||
c_COMPLEX64_DEFAULT complex64
|
||||
c_COMPLEX128_DEFAULT complex128
|
||||
c_FLOAT32_DEFAULT float32
|
||||
c_FLOAT64_DEFAULT float64
|
||||
c_INT64_DEFAULT int64
|
||||
c_UINT64_DEFAULT uint64
|
||||
c_INT32_DEFAULT int32
|
||||
c_UINT32_DEFAULT uint32
|
||||
c_INT16_DEFAULT int16
|
||||
c_UINT16_DEFAULT uint16
|
||||
c_INT8_DEFAULT int8
|
||||
c_UINT8_DEFAULT uint8
|
||||
c_INT_DEFAULT int
|
||||
c_UINT_DEFAULT uint
|
||||
c_TIME_DEFAULT time.Time
|
||||
)
|
||||
|
||||
var (
|
||||
IntType = reflect.TypeOf(c_INT_DEFAULT)
|
||||
Int8Type = reflect.TypeOf(c_INT8_DEFAULT)
|
||||
Int16Type = reflect.TypeOf(c_INT16_DEFAULT)
|
||||
Int32Type = reflect.TypeOf(c_INT32_DEFAULT)
|
||||
Int64Type = reflect.TypeOf(c_INT64_DEFAULT)
|
||||
|
||||
UintType = reflect.TypeOf(c_UINT_DEFAULT)
|
||||
Uint8Type = reflect.TypeOf(c_UINT8_DEFAULT)
|
||||
Uint16Type = reflect.TypeOf(c_UINT16_DEFAULT)
|
||||
Uint32Type = reflect.TypeOf(c_UINT32_DEFAULT)
|
||||
Uint64Type = reflect.TypeOf(c_UINT64_DEFAULT)
|
||||
|
||||
Float32Type = reflect.TypeOf(c_FLOAT32_DEFAULT)
|
||||
Float64Type = reflect.TypeOf(c_FLOAT64_DEFAULT)
|
||||
|
||||
Complex64Type = reflect.TypeOf(c_COMPLEX64_DEFAULT)
|
||||
Complex128Type = reflect.TypeOf(c_COMPLEX128_DEFAULT)
|
||||
|
||||
StringType = reflect.TypeOf(c_EMPTY_STRING)
|
||||
BoolType = reflect.TypeOf(c_BOOL_DEFAULT)
|
||||
ByteType = reflect.TypeOf(c_BYTE_DEFAULT)
|
||||
BytesType = reflect.SliceOf(ByteType)
|
||||
|
||||
TimeType = reflect.TypeOf(c_TIME_DEFAULT)
|
||||
)
|
||||
|
||||
var (
|
||||
PtrIntType = reflect.PtrTo(IntType)
|
||||
PtrInt8Type = reflect.PtrTo(Int8Type)
|
||||
PtrInt16Type = reflect.PtrTo(Int16Type)
|
||||
PtrInt32Type = reflect.PtrTo(Int32Type)
|
||||
PtrInt64Type = reflect.PtrTo(Int64Type)
|
||||
|
||||
PtrUintType = reflect.PtrTo(UintType)
|
||||
PtrUint8Type = reflect.PtrTo(Uint8Type)
|
||||
PtrUint16Type = reflect.PtrTo(Uint16Type)
|
||||
PtrUint32Type = reflect.PtrTo(Uint32Type)
|
||||
PtrUint64Type = reflect.PtrTo(Uint64Type)
|
||||
|
||||
PtrFloat32Type = reflect.PtrTo(Float32Type)
|
||||
PtrFloat64Type = reflect.PtrTo(Float64Type)
|
||||
|
||||
PtrComplex64Type = reflect.PtrTo(Complex64Type)
|
||||
PtrComplex128Type = reflect.PtrTo(Complex128Type)
|
||||
|
||||
PtrStringType = reflect.PtrTo(StringType)
|
||||
PtrBoolType = reflect.PtrTo(BoolType)
|
||||
PtrByteType = reflect.PtrTo(ByteType)
|
||||
|
||||
PtrTimeType = reflect.PtrTo(TimeType)
|
||||
)
|
||||
|
||||
// Type2SQLType generate SQLType acorrding Go's type
|
||||
func Type2SQLType(t reflect.Type) (st SQLType) {
|
||||
switch k := t.Kind(); k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
||||
st = SQLType{Int, 0, 0}
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
st = SQLType{BigInt, 0, 0}
|
||||
case reflect.Float32:
|
||||
st = SQLType{Float, 0, 0}
|
||||
case reflect.Float64:
|
||||
st = SQLType{Double, 0, 0}
|
||||
case reflect.Complex64, reflect.Complex128:
|
||||
st = SQLType{Varchar, 64, 0}
|
||||
case reflect.Array, reflect.Slice, reflect.Map:
|
||||
if t.Elem() == reflect.TypeOf(c_BYTE_DEFAULT) {
|
||||
st = SQLType{Blob, 0, 0}
|
||||
} else {
|
||||
st = SQLType{Text, 0, 0}
|
||||
}
|
||||
case reflect.Bool:
|
||||
st = SQLType{Bool, 0, 0}
|
||||
case reflect.String:
|
||||
st = SQLType{Varchar, 255, 0}
|
||||
case reflect.Struct:
|
||||
if t.ConvertibleTo(TimeType) {
|
||||
st = SQLType{DateTime, 0, 0}
|
||||
} else {
|
||||
// TODO need to handle association struct
|
||||
st = SQLType{Text, 0, 0}
|
||||
}
|
||||
case reflect.Ptr:
|
||||
st = Type2SQLType(t.Elem())
|
||||
default:
|
||||
st = SQLType{Text, 0, 0}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// default sql type change to go types
|
||||
func SQLType2Type(st SQLType) reflect.Type {
|
||||
name := strings.ToUpper(st.Name)
|
||||
switch name {
|
||||
case Bit, TinyInt, SmallInt, MediumInt, Int, Integer, Serial:
|
||||
return reflect.TypeOf(1)
|
||||
case BigInt, BigSerial:
|
||||
return reflect.TypeOf(int64(1))
|
||||
case Float, Real:
|
||||
return reflect.TypeOf(float32(1))
|
||||
case Double:
|
||||
return reflect.TypeOf(float64(1))
|
||||
case Char, Varchar, NVarchar, TinyText, Text, MediumText, LongText, Enum, Set, Uuid, Clob:
|
||||
return reflect.TypeOf("")
|
||||
case TinyBlob, Blob, LongBlob, Bytea, Binary, MediumBlob, VarBinary:
|
||||
return reflect.TypeOf([]byte{})
|
||||
case Bool:
|
||||
return reflect.TypeOf(true)
|
||||
case DateTime, Date, Time, TimeStamp, TimeStampz:
|
||||
return reflect.TypeOf(c_TIME_DEFAULT)
|
||||
case Decimal, Numeric:
|
||||
return reflect.TypeOf("")
|
||||
default:
|
||||
return reflect.TypeOf("")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
## Contributing to pq
|
||||
|
||||
`pq` has a backlog of pull requests, but contributions are still very
|
||||
much welcome. You can help with patch review, submitting bug reports,
|
||||
or adding new functionality. There is no formal style guide, but
|
||||
please conform to the style of existing code and general Go formatting
|
||||
conventions when submitting patches.
|
||||
|
||||
### Patch review
|
||||
|
||||
Help review existing open pull requests by commenting on the code or
|
||||
proposed functionality.
|
||||
|
||||
### Bug reports
|
||||
|
||||
We appreciate any bug reports, but especially ones with self-contained
|
||||
(doesn't depend on code outside of pq), minimal (can't be simplified
|
||||
further) test cases. It's especially helpful if you can submit a pull
|
||||
request with just the failing test case (you'll probably want to
|
||||
pattern it after the tests in
|
||||
[conn_test.go](https://github.com/lib/pq/blob/master/conn_test.go).
|
||||
|
||||
### New functionality
|
||||
|
||||
There are a number of pending patches for new functionality, so
|
||||
additional feature patches will take a while to merge. Still, patches
|
||||
are generally reviewed based on usefulness and complexity in addition
|
||||
to time-in-queue, so if you have a knockout idea, take a shot. Feel
|
||||
free to open an issue discussion your proposed patch beforehand.
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
Copyright (c) 2011-2013, 'pq' Contributors
|
||||
Portions Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
# pq - A pure Go postgres driver for Go's database/sql package
|
||||
|
||||
[](https://travis-ci.org/lib/pq)
|
||||
|
||||
## Install
|
||||
|
||||
go get github.com/lib/pq
|
||||
|
||||
## Docs
|
||||
|
||||
For detailed documentation and basic usage examples, please see the package
|
||||
documentation at <http://godoc.org/github.com/lib/pq>.
|
||||
|
||||
## Tests
|
||||
|
||||
`go test` is used for testing. A running PostgreSQL server is
|
||||
required, with the ability to log in. The default database to connect
|
||||
to test with is "pqgotest," but it can be overridden using environment
|
||||
variables.
|
||||
|
||||
Example:
|
||||
|
||||
PGHOST=/run/postgresql go test github.com/lib/pq
|
||||
|
||||
Optionally, a benchmark suite can be run as part of the tests:
|
||||
|
||||
PGHOST=/run/postgresql go test -bench .
|
||||
|
||||
## Features
|
||||
|
||||
* SSL
|
||||
* Handles bad connections for `database/sql`
|
||||
* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`)
|
||||
* Scan binary blobs correctly (i.e. `bytea`)
|
||||
* Package for `hstore` support
|
||||
* COPY FROM support
|
||||
* pq.ParseURL for converting urls to connection strings for sql.Open.
|
||||
* Many libpq compatible environment variables
|
||||
* Unix socket support
|
||||
* Notifications: `LISTEN`/`NOTIFY`
|
||||
* pgpass support
|
||||
|
||||
## Future / Things you can help with
|
||||
|
||||
* Better COPY FROM / COPY TO (see discussion in #181)
|
||||
|
||||
## Thank you (alphabetical)
|
||||
|
||||
Some of these contributors are from the original library `bmizerany/pq.go` whose
|
||||
code still exists in here.
|
||||
|
||||
* Andy Balholm (andybalholm)
|
||||
* Ben Berkert (benburkert)
|
||||
* Benjamin Heatwole (bheatwole)
|
||||
* Bill Mill (llimllib)
|
||||
* Bjørn Madsen (aeons)
|
||||
* Blake Gentry (bgentry)
|
||||
* Brad Fitzpatrick (bradfitz)
|
||||
* Charlie Melbye (cmelbye)
|
||||
* Chris Bandy (cbandy)
|
||||
* Chris Gilling (cgilling)
|
||||
* Chris Walsh (cwds)
|
||||
* Dan Sosedoff (sosedoff)
|
||||
* Daniel Farina (fdr)
|
||||
* Eric Chlebek (echlebek)
|
||||
* Eric Garrido (minusnine)
|
||||
* Eric Urban (hydrogen18)
|
||||
* Everyone at The Go Team
|
||||
* Evan Shaw (edsrzf)
|
||||
* Ewan Chou (coocood)
|
||||
* Fazal Majid (fazalmajid)
|
||||
* Federico Romero (federomero)
|
||||
* Fumin (fumin)
|
||||
* Gary Burd (garyburd)
|
||||
* Heroku (heroku)
|
||||
* James Pozdena (jpoz)
|
||||
* Jason McVetta (jmcvetta)
|
||||
* Jeremy Jay (pbnjay)
|
||||
* Joakim Sernbrant (serbaut)
|
||||
* John Gallagher (jgallagher)
|
||||
* Jonathan Rudenberg (titanous)
|
||||
* Joël Stemmer (jstemmer)
|
||||
* Kamil Kisiel (kisielk)
|
||||
* Kelly Dunn (kellydunn)
|
||||
* Keith Rarick (kr)
|
||||
* Kir Shatrov (kirs)
|
||||
* Lann Martin (lann)
|
||||
* Maciek Sakrejda (uhoh-itsmaciek)
|
||||
* Marc Brinkmann (mbr)
|
||||
* Marko Tiikkaja (johto)
|
||||
* Matt Newberry (MattNewberry)
|
||||
* Matt Robenolt (mattrobenolt)
|
||||
* Martin Olsen (martinolsen)
|
||||
* Mike Lewis (mikelikespie)
|
||||
* Nicolas Patry (Narsil)
|
||||
* Oliver Tonnhofer (olt)
|
||||
* Patrick Hayes (phayes)
|
||||
* Paul Hammond (paulhammond)
|
||||
* Ryan Smith (ryandotsmith)
|
||||
* Samuel Stauffer (samuel)
|
||||
* Timothée Peignier (cyberdelia)
|
||||
* Travis Cline (tmc)
|
||||
* TruongSinh Tran-Nguyen (truongsinh)
|
||||
* Yaismel Miranda (ympons)
|
||||
* notedit (notedit)
|
||||
|
|
@ -0,0 +1,727 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var typeByteSlice = reflect.TypeOf([]byte{})
|
||||
var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
|
||||
|
||||
// Array returns the optimal driver.Valuer and sql.Scanner for an array or
|
||||
// slice of any dimension.
|
||||
//
|
||||
// For example:
|
||||
// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
|
||||
//
|
||||
// var x []sql.NullInt64
|
||||
// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x))
|
||||
//
|
||||
// Scanning multi-dimensional arrays is not supported. Arrays where the lower
|
||||
// bound is not one (such as `[0:0]={1}') are not supported.
|
||||
func Array(a interface{}) interface {
|
||||
driver.Valuer
|
||||
sql.Scanner
|
||||
} {
|
||||
switch a := a.(type) {
|
||||
case []bool:
|
||||
return (*BoolArray)(&a)
|
||||
case []float64:
|
||||
return (*Float64Array)(&a)
|
||||
case []int64:
|
||||
return (*Int64Array)(&a)
|
||||
case []string:
|
||||
return (*StringArray)(&a)
|
||||
|
||||
case *[]bool:
|
||||
return (*BoolArray)(a)
|
||||
case *[]float64:
|
||||
return (*Float64Array)(a)
|
||||
case *[]int64:
|
||||
return (*Int64Array)(a)
|
||||
case *[]string:
|
||||
return (*StringArray)(a)
|
||||
}
|
||||
|
||||
return GenericArray{a}
|
||||
}
|
||||
|
||||
// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner
|
||||
// to override the array delimiter used by GenericArray.
|
||||
type ArrayDelimiter interface {
|
||||
// ArrayDelimiter returns the delimiter character(s) for this element's type.
|
||||
ArrayDelimiter() string
|
||||
}
|
||||
|
||||
// BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
|
||||
type BoolArray []bool
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *BoolArray) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to BoolArray", src)
|
||||
}
|
||||
|
||||
func (a *BoolArray) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(BoolArray, len(elems))
|
||||
for i, v := range elems {
|
||||
if len(v) != 1 {
|
||||
return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
|
||||
}
|
||||
switch v[0] {
|
||||
case 't':
|
||||
b[i] = true
|
||||
case 'f':
|
||||
b[i] = false
|
||||
default:
|
||||
return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a BoolArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be exactly two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1+2*n)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
b[2*i] = ','
|
||||
if a[i] {
|
||||
b[1+2*i] = 't'
|
||||
} else {
|
||||
b[1+2*i] = 'f'
|
||||
}
|
||||
}
|
||||
|
||||
b[0] = '{'
|
||||
b[2*n] = '}'
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type.
|
||||
type ByteaArray [][]byte
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *ByteaArray) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to ByteaArray", src)
|
||||
}
|
||||
|
||||
func (a *ByteaArray) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "ByteaArray")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(ByteaArray, len(elems))
|
||||
for i, v := range elems {
|
||||
b[i], err = parseBytea(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface. It uses the "hex" format which
|
||||
// is only supported on PostgreSQL 9.0 or newer.
|
||||
func (a ByteaArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, 2*N bytes of quotes,
|
||||
// 3*N bytes of hex formatting, and N-1 bytes of delimiters.
|
||||
size := 1 + 6*n
|
||||
for _, x := range a {
|
||||
size += hex.EncodedLen(len(x))
|
||||
}
|
||||
|
||||
b := make([]byte, size)
|
||||
|
||||
for i, s := 0, b; i < n; i++ {
|
||||
o := copy(s, `,"\\x`)
|
||||
o += hex.Encode(s[o:], a[i])
|
||||
s[o] = '"'
|
||||
s = s[o+1:]
|
||||
}
|
||||
|
||||
b[0] = '{'
|
||||
b[size-1] = '}'
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Float64Array represents a one-dimensional array of the PostgreSQL double
|
||||
// precision type.
|
||||
type Float64Array []float64
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Float64Array) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to Float64Array", src)
|
||||
}
|
||||
|
||||
func (a *Float64Array) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Float64Array, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Float64Array) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// GenericArray implements the driver.Valuer and sql.Scanner interfaces for
|
||||
// an array or slice of any dimension.
|
||||
type GenericArray struct{ A interface{} }
|
||||
|
||||
func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
|
||||
var assign func([]byte, reflect.Value) error
|
||||
var del = ","
|
||||
|
||||
// TODO calculate the assign function for other types
|
||||
// TODO repeat this section on the element type of arrays or slices (multidimensional)
|
||||
{
|
||||
if reflect.PtrTo(rt).Implements(typeSqlScanner) {
|
||||
// dest is always addressable because it is an element of a slice.
|
||||
assign = func(src []byte, dest reflect.Value) (err error) {
|
||||
ss := dest.Addr().Interface().(sql.Scanner)
|
||||
if src == nil {
|
||||
err = ss.Scan(nil)
|
||||
} else {
|
||||
err = ss.Scan(src)
|
||||
}
|
||||
return
|
||||
}
|
||||
goto FoundType
|
||||
}
|
||||
|
||||
assign = func([]byte, reflect.Value) error {
|
||||
return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt)
|
||||
}
|
||||
}
|
||||
|
||||
FoundType:
|
||||
|
||||
if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok {
|
||||
del = ad.ArrayDelimiter()
|
||||
}
|
||||
|
||||
return rt, assign, del
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a GenericArray) Scan(src interface{}) error {
|
||||
dpv := reflect.ValueOf(a.A)
|
||||
switch {
|
||||
case dpv.Kind() != reflect.Ptr:
|
||||
return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
|
||||
case dpv.IsNil():
|
||||
return fmt.Errorf("pq: destination %T is nil", a.A)
|
||||
}
|
||||
|
||||
dv := dpv.Elem()
|
||||
switch dv.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Array:
|
||||
default:
|
||||
return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
|
||||
}
|
||||
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src, dv)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src), dv)
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type())
|
||||
}
|
||||
|
||||
func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error {
|
||||
dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
|
||||
dims, elems, err := parseArray(src, []byte(del))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO allow multidimensional
|
||||
|
||||
if len(dims) > 1 {
|
||||
return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented",
|
||||
strings.Replace(fmt.Sprint(dims), " ", "][", -1))
|
||||
}
|
||||
|
||||
// Treat a zero-dimensional array like an array with a single dimension of zero.
|
||||
if len(dims) == 0 {
|
||||
dims = append(dims, 0)
|
||||
}
|
||||
|
||||
for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
|
||||
switch rt.Kind() {
|
||||
case reflect.Slice:
|
||||
case reflect.Array:
|
||||
if rt.Len() != dims[i] {
|
||||
return fmt.Errorf("pq: cannot convert ARRAY%s to %s",
|
||||
strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
|
||||
}
|
||||
default:
|
||||
// TODO handle multidimensional
|
||||
}
|
||||
}
|
||||
|
||||
values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
|
||||
for i, e := range elems {
|
||||
if err := assign(e, values.Index(i)); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO handle multidimensional
|
||||
|
||||
switch dv.Kind() {
|
||||
case reflect.Slice:
|
||||
dv.Set(values.Slice(0, dims[0]))
|
||||
case reflect.Array:
|
||||
for i := 0; i < dims[0]; i++ {
|
||||
dv.Index(i).Set(values.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a GenericArray) Value() (driver.Value, error) {
|
||||
if a.A == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(a.A)
|
||||
|
||||
if k := rv.Kind(); k != reflect.Array && k != reflect.Slice {
|
||||
return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A)
|
||||
}
|
||||
|
||||
if n := rv.Len(); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 0, 1+2*n)
|
||||
|
||||
b, _, err := appendArray(b, rv, n)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// Int64Array represents a one-dimensional array of the PostgreSQL integer types.
|
||||
type Int64Array []int64
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *Int64Array) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to Int64Array", src)
|
||||
}
|
||||
|
||||
func (a *Int64Array) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(Int64Array, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a Int64Array) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, N bytes of values,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+2*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = strconv.AppendInt(b, a[0], 10)
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = strconv.AppendInt(b, a[i], 10)
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// StringArray represents a one-dimensional array of the PostgreSQL character types.
|
||||
type StringArray []string
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (a *StringArray) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case []byte:
|
||||
return a.scanBytes(src)
|
||||
case string:
|
||||
return a.scanBytes([]byte(src))
|
||||
}
|
||||
|
||||
return fmt.Errorf("pq: cannot convert %T to StringArray", src)
|
||||
}
|
||||
|
||||
func (a *StringArray) scanBytes(src []byte) error {
|
||||
elems, err := scanLinearArray(src, []byte{','}, "StringArray")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(elems) == 0 {
|
||||
*a = (*a)[:0]
|
||||
} else {
|
||||
b := make(StringArray, len(elems))
|
||||
for i, v := range elems {
|
||||
if b[i] = string(v); v == nil {
|
||||
return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i)
|
||||
}
|
||||
}
|
||||
*a = b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (a StringArray) Value() (driver.Value, error) {
|
||||
if a == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if n := len(a); n > 0 {
|
||||
// There will be at least two curly brackets, 2*N bytes of quotes,
|
||||
// and N-1 bytes of delimiters.
|
||||
b := make([]byte, 1, 1+3*n)
|
||||
b[0] = '{'
|
||||
|
||||
b = appendArrayQuotedBytes(b, []byte(a[0]))
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, ',')
|
||||
b = appendArrayQuotedBytes(b, []byte(a[i]))
|
||||
}
|
||||
|
||||
return string(append(b, '}')), nil
|
||||
}
|
||||
|
||||
return "{}", nil
|
||||
}
|
||||
|
||||
// appendArray appends rv to the buffer, returning the extended buffer and
|
||||
// the delimiter used between elements.
|
||||
//
|
||||
// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
|
||||
func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
|
||||
var del string
|
||||
var err error
|
||||
|
||||
b = append(b, '{')
|
||||
|
||||
if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
for i := 1; i < n; i++ {
|
||||
b = append(b, del...)
|
||||
if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
}
|
||||
|
||||
return append(b, '}'), del, nil
|
||||
}
|
||||
|
||||
// appendArrayElement appends rv to the buffer, returning the extended buffer
|
||||
// and the delimiter to use before the next element.
|
||||
//
|
||||
// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
|
||||
// using driver.DefaultParameterConverter and the resulting []byte or string
|
||||
// is double-quoted.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||
func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
|
||||
if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
|
||||
if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
|
||||
if n := rv.Len(); n > 0 {
|
||||
return appendArray(b, rv, n)
|
||||
}
|
||||
|
||||
return b, "", nil
|
||||
}
|
||||
}
|
||||
|
||||
var del string = ","
|
||||
var err error
|
||||
var iv interface{} = rv.Interface()
|
||||
|
||||
if ad, ok := iv.(ArrayDelimiter); ok {
|
||||
del = ad.ArrayDelimiter()
|
||||
}
|
||||
|
||||
if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
switch v := iv.(type) {
|
||||
case nil:
|
||||
return append(b, "NULL"...), del, nil
|
||||
case []byte:
|
||||
return appendArrayQuotedBytes(b, v), del, nil
|
||||
case string:
|
||||
return appendArrayQuotedBytes(b, []byte(v)), del, nil
|
||||
}
|
||||
|
||||
b, err = appendValue(b, iv)
|
||||
return b, del, err
|
||||
}
|
||||
|
||||
func appendArrayQuotedBytes(b, v []byte) []byte {
|
||||
b = append(b, '"')
|
||||
for {
|
||||
i := bytes.IndexAny(v, `"\`)
|
||||
if i < 0 {
|
||||
b = append(b, v...)
|
||||
break
|
||||
}
|
||||
if i > 0 {
|
||||
b = append(b, v[:i]...)
|
||||
}
|
||||
b = append(b, '\\', v[i])
|
||||
v = v[i+1:]
|
||||
}
|
||||
return append(b, '"')
|
||||
}
|
||||
|
||||
func appendValue(b []byte, v driver.Value) ([]byte, error) {
|
||||
return append(b, encode(nil, v, 0)...), nil
|
||||
}
|
||||
|
||||
// parseArray extracts the dimensions and elements of an array represented in
|
||||
// text format. Only representations emitted by the backend are supported.
|
||||
// Notably, whitespace around brackets and delimiters is significant, and NULL
|
||||
// is case-sensitive.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
|
||||
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
|
||||
var depth, i int
|
||||
|
||||
if len(src) < 1 || src[0] != '{' {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
|
||||
}
|
||||
|
||||
Open:
|
||||
for i < len(src) {
|
||||
switch src[i] {
|
||||
case '{':
|
||||
depth++
|
||||
i++
|
||||
case '}':
|
||||
elems = make([][]byte, 0)
|
||||
goto Close
|
||||
default:
|
||||
break Open
|
||||
}
|
||||
}
|
||||
dims = make([]int, i)
|
||||
|
||||
Element:
|
||||
for i < len(src) {
|
||||
switch src[i] {
|
||||
case '{':
|
||||
depth++
|
||||
dims[depth-1] = 0
|
||||
i++
|
||||
case '"':
|
||||
var elem = []byte{}
|
||||
var escape bool
|
||||
for i++; i < len(src); i++ {
|
||||
if escape {
|
||||
elem = append(elem, src[i])
|
||||
escape = false
|
||||
} else {
|
||||
switch src[i] {
|
||||
default:
|
||||
elem = append(elem, src[i])
|
||||
case '\\':
|
||||
escape = true
|
||||
case '"':
|
||||
elems = append(elems, elem)
|
||||
i++
|
||||
break Element
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
for start := i; i < len(src); i++ {
|
||||
if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
|
||||
elem := src[start:i]
|
||||
if len(elem) == 0 {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
if bytes.Equal(elem, []byte("NULL")) {
|
||||
elem = nil
|
||||
}
|
||||
elems = append(elems, elem)
|
||||
break Element
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i < len(src) {
|
||||
if bytes.HasPrefix(src[i:], del) {
|
||||
dims[depth-1]++
|
||||
i += len(del)
|
||||
goto Element
|
||||
} else if src[i] == '}' {
|
||||
dims[depth-1]++
|
||||
depth--
|
||||
i++
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
}
|
||||
|
||||
Close:
|
||||
for i < len(src) {
|
||||
if src[i] == '}' && depth > 0 {
|
||||
depth--
|
||||
i++
|
||||
} else {
|
||||
return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
|
||||
}
|
||||
}
|
||||
if depth > 0 {
|
||||
err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
|
||||
}
|
||||
if err == nil {
|
||||
for _, d := range dims {
|
||||
if (len(elems) % d) != 0 {
|
||||
err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
|
||||
dims, elems, err := parseArray(src, del)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dims) > 1 {
|
||||
return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
|
||||
}
|
||||
return elems, err
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,435 @@
|
|||
// +build go1.1
|
||||
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
var (
|
||||
selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'"
|
||||
selectSeriesQuery = "SELECT generate_series(1, 100)"
|
||||
)
|
||||
|
||||
func BenchmarkSelectString(b *testing.B) {
|
||||
var result string
|
||||
benchQuery(b, selectStringQuery, &result)
|
||||
}
|
||||
|
||||
func BenchmarkSelectSeries(b *testing.B) {
|
||||
var result int
|
||||
benchQuery(b, selectSeriesQuery, &result)
|
||||
}
|
||||
|
||||
func benchQuery(b *testing.B, query string, result interface{}) {
|
||||
b.StopTimer()
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchQueryLoop(b, db, query, result)
|
||||
}
|
||||
}
|
||||
|
||||
func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) {
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
err = rows.Scan(result)
|
||||
if err != nil {
|
||||
b.Fatal("failed to scan", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reading from circularConn yields content[:prefixLen] once, followed by
|
||||
// content[prefixLen:] over and over again. It never returns EOF.
|
||||
type circularConn struct {
|
||||
content string
|
||||
prefixLen int
|
||||
pos int
|
||||
net.Conn // for all other net.Conn methods that will never be called
|
||||
}
|
||||
|
||||
func (r *circularConn) Read(b []byte) (n int, err error) {
|
||||
n = copy(b, r.content[r.pos:])
|
||||
r.pos += n
|
||||
if r.pos >= len(r.content) {
|
||||
r.pos = r.prefixLen
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil }
|
||||
|
||||
func (r *circularConn) Close() error { return nil }
|
||||
|
||||
func fakeConn(content string, prefixLen int) *conn {
|
||||
c := &circularConn{content: content, prefixLen: prefixLen}
|
||||
return &conn{buf: bufio.NewReader(c), c: c}
|
||||
}
|
||||
|
||||
// This benchmark is meant to be the same as BenchmarkSelectString, but takes
|
||||
// out some of the factors this package can't control. The numbers are less noisy,
|
||||
// but also the costs of network communication aren't accurately represented.
|
||||
func BenchmarkMockSelectString(b *testing.B) {
|
||||
b.StopTimer()
|
||||
// taken from a recorded run of BenchmarkSelectString
|
||||
// See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html
|
||||
const response = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"2\x00\x00\x00\x04" +
|
||||
"D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +
|
||||
"C\x00\x00\x00\rSELECT 1\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"3\x00\x00\x00\x04" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(response, 0)
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchMockQuery(b, c, selectStringQuery)
|
||||
}
|
||||
}
|
||||
|
||||
var seriesRowData = func() string {
|
||||
var buf bytes.Buffer
|
||||
for i := 1; i <= 100; i++ {
|
||||
digits := byte(2)
|
||||
if i >= 100 {
|
||||
digits = 3
|
||||
} else if i < 10 {
|
||||
digits = 1
|
||||
}
|
||||
buf.WriteString("D\x00\x00\x00")
|
||||
buf.WriteByte(10 + digits)
|
||||
buf.WriteString("\x00\x01\x00\x00\x00")
|
||||
buf.WriteByte(digits)
|
||||
buf.WriteString(strconv.Itoa(i))
|
||||
}
|
||||
return buf.String()
|
||||
}()
|
||||
|
||||
func BenchmarkMockSelectSeries(b *testing.B) {
|
||||
b.StopTimer()
|
||||
var response = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"2\x00\x00\x00\x04" +
|
||||
seriesRowData +
|
||||
"C\x00\x00\x00\x0fSELECT 100\x00" +
|
||||
"Z\x00\x00\x00\x05I" +
|
||||
"3\x00\x00\x00\x04" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(response, 0)
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchMockQuery(b, c, selectSeriesQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func benchMockQuery(b *testing.B, c *conn, query string) {
|
||||
stmt, err := c.Prepare(query)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.Query(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var dest [1]driver.Value
|
||||
for {
|
||||
if err := rows.Next(dest[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPreparedSelectString(b *testing.B) {
|
||||
var result string
|
||||
benchPreparedQuery(b, selectStringQuery, &result)
|
||||
}
|
||||
|
||||
func BenchmarkPreparedSelectSeries(b *testing.B) {
|
||||
var result int
|
||||
benchPreparedQuery(b, selectSeriesQuery, &result)
|
||||
}
|
||||
|
||||
func benchPreparedQuery(b *testing.B, query string, result interface{}) {
|
||||
b.StopTimer()
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
stmt, err := db.Prepare(query)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchPreparedQueryLoop(b, db, stmt, result)
|
||||
}
|
||||
}
|
||||
|
||||
func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) {
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
rows.Close()
|
||||
b.Fatal("no rows")
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&result)
|
||||
if err != nil {
|
||||
b.Fatal("failed to scan")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// See the comment for BenchmarkMockSelectString.
|
||||
func BenchmarkMockPreparedSelectString(b *testing.B) {
|
||||
b.StopTimer()
|
||||
const parseResponse = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
const responses = parseResponse +
|
||||
"2\x00\x00\x00\x04" +
|
||||
"D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +
|
||||
"C\x00\x00\x00\rSELECT 1\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(responses, len(parseResponse))
|
||||
|
||||
stmt, err := c.Prepare(selectStringQuery)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchPreparedMockQuery(b, c, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMockPreparedSelectSeries(b *testing.B) {
|
||||
b.StopTimer()
|
||||
const parseResponse = "1\x00\x00\x00\x04" +
|
||||
"t\x00\x00\x00\x06\x00\x00" +
|
||||
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
var responses = parseResponse +
|
||||
"2\x00\x00\x00\x04" +
|
||||
seriesRowData +
|
||||
"C\x00\x00\x00\x0fSELECT 100\x00" +
|
||||
"Z\x00\x00\x00\x05I"
|
||||
c := fakeConn(responses, len(parseResponse))
|
||||
|
||||
stmt, err := c.Prepare(selectSeriesQuery)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchPreparedMockQuery(b, c, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) {
|
||||
rows, err := stmt.Query(nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
var dest [1]driver.Value
|
||||
for {
|
||||
if err := rows.Next(dest[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeInt64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, int64(1234), oid.T_int8)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeFloat64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, 3.14159, oid.T_float8)
|
||||
}
|
||||
}
|
||||
|
||||
var testByteString = []byte("abcdefghijklmnopqrstuvwxyz")
|
||||
|
||||
func BenchmarkEncodeByteaHex(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea)
|
||||
}
|
||||
}
|
||||
func BenchmarkEncodeByteaEscape(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeBool(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, true, oid.T_bool)
|
||||
}
|
||||
}
|
||||
|
||||
var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local)
|
||||
|
||||
func BenchmarkEncodeTimestamptz(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz)
|
||||
}
|
||||
}
|
||||
|
||||
var testIntBytes = []byte("1234")
|
||||
|
||||
func BenchmarkDecodeInt64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testIntBytes, oid.T_int8, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
var testFloatBytes = []byte("3.14159")
|
||||
|
||||
func BenchmarkDecodeFloat64(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testFloatBytes, oid.T_float8, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
var testBoolBytes = []byte{'t'}
|
||||
|
||||
func BenchmarkDecodeBool(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testBoolBytes, oid.T_bool, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeBool(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
rows, err := db.Query("select true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rows.Close()
|
||||
}
|
||||
|
||||
var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07")
|
||||
|
||||
func BenchmarkDecodeTimestamptz(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) {
|
||||
oldProcs := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(oldProcs)
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
globalLocationCache = newLocationCache()
|
||||
|
||||
f := func(wg *sync.WaitGroup, loops int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < loops; i++ {
|
||||
decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
|
||||
}
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
b.ResetTimer()
|
||||
for j := 0; j < 10; j++ {
|
||||
wg.Add(1)
|
||||
go f(wg, b.N/10)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func BenchmarkLocationCache(b *testing.B) {
|
||||
globalLocationCache = newLocationCache()
|
||||
for i := 0; i < b.N; i++ {
|
||||
globalLocationCache.getLocation(rand.Intn(10000))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLocationCacheMultiThread(b *testing.B) {
|
||||
oldProcs := runtime.GOMAXPROCS(0)
|
||||
defer runtime.GOMAXPROCS(oldProcs)
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
globalLocationCache = newLocationCache()
|
||||
|
||||
f := func(wg *sync.WaitGroup, loops int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < loops; i++ {
|
||||
globalLocationCache.getLocation(rand.Intn(10000))
|
||||
}
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
b.ResetTimer()
|
||||
for j := 0; j < 10; j++ {
|
||||
wg.Add(1)
|
||||
go f(wg, b.N/10)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress test the performance of parsing results from the wire.
|
||||
func BenchmarkResultParsing(b *testing.B) {
|
||||
b.StopTimer()
|
||||
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
_, err := db.Exec("BEGIN")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
res, err := db.Query("SELECT generate_series(1, 50000)")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
res.Close()
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
type readBuf []byte
|
||||
|
||||
func (b *readBuf) int32() (n int) {
|
||||
n = int(int32(binary.BigEndian.Uint32(*b)))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) oid() (n oid.Oid) {
|
||||
n = oid.Oid(binary.BigEndian.Uint32(*b))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
// N.B: this is actually an unsigned 16-bit integer, unlike int32
|
||||
func (b *readBuf) int16() (n int) {
|
||||
n = int(binary.BigEndian.Uint16(*b))
|
||||
*b = (*b)[2:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) string() string {
|
||||
i := bytes.IndexByte(*b, 0)
|
||||
if i < 0 {
|
||||
errorf("invalid message format; expected string terminator")
|
||||
}
|
||||
s := (*b)[:i]
|
||||
*b = (*b)[i+1:]
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (b *readBuf) next(n int) (v []byte) {
|
||||
v = (*b)[:n]
|
||||
*b = (*b)[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) byte() byte {
|
||||
return b.next(1)[0]
|
||||
}
|
||||
|
||||
type writeBuf struct {
|
||||
buf []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (b *writeBuf) int32(n int) {
|
||||
x := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(x, uint32(n))
|
||||
b.buf = append(b.buf, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) int16(n int) {
|
||||
x := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(x, uint16(n))
|
||||
b.buf = append(b.buf, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) string(s string) {
|
||||
b.buf = append(b.buf, (s + "\000")...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) byte(c byte) {
|
||||
b.buf = append(b.buf, c)
|
||||
}
|
||||
|
||||
func (b *writeBuf) bytes(v []byte) {
|
||||
b.buf = append(b.buf, v...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) wrap() []byte {
|
||||
p := b.buf[b.pos:]
|
||||
binary.BigEndian.PutUint32(p, uint32(len(p)))
|
||||
return b.buf
|
||||
}
|
||||
|
||||
func (b *writeBuf) next(c byte) {
|
||||
p := b.buf[b.pos:]
|
||||
binary.BigEndian.PutUint32(p, uint32(len(p)))
|
||||
b.pos = len(b.buf) + 1
|
||||
b.buf = append(b.buf, c, 0, 0, 0, 0)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,269 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
errCopyInClosed = errors.New("pq: copyin statement has already been closed")
|
||||
errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
|
||||
errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
|
||||
errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
|
||||
errCopyInProgress = errors.New("pq: COPY in progress")
|
||||
)
|
||||
|
||||
// CopyIn creates a COPY FROM statement which can be prepared with
|
||||
// Tx.Prepare(). The target table should be visible in search_path.
|
||||
func CopyIn(table string, columns ...string) string {
|
||||
stmt := "COPY " + QuoteIdentifier(table) + " ("
|
||||
for i, col := range columns {
|
||||
if i != 0 {
|
||||
stmt += ", "
|
||||
}
|
||||
stmt += QuoteIdentifier(col)
|
||||
}
|
||||
stmt += ") FROM STDIN"
|
||||
return stmt
|
||||
}
|
||||
|
||||
// CopyInSchema creates a COPY FROM statement which can be prepared with
|
||||
// Tx.Prepare().
|
||||
func CopyInSchema(schema, table string, columns ...string) string {
|
||||
stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
|
||||
for i, col := range columns {
|
||||
if i != 0 {
|
||||
stmt += ", "
|
||||
}
|
||||
stmt += QuoteIdentifier(col)
|
||||
}
|
||||
stmt += ") FROM STDIN"
|
||||
return stmt
|
||||
}
|
||||
|
||||
type copyin struct {
|
||||
cn *conn
|
||||
buffer []byte
|
||||
rowData chan []byte
|
||||
done chan bool
|
||||
|
||||
closed bool
|
||||
|
||||
sync.Mutex // guards err
|
||||
err error
|
||||
}
|
||||
|
||||
const ciBufferSize = 64 * 1024
|
||||
|
||||
// flush buffer before the buffer is filled up and needs reallocation
|
||||
const ciBufferFlushSize = 63 * 1024
|
||||
|
||||
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
|
||||
if !cn.isInTransaction() {
|
||||
return nil, errCopyNotSupportedOutsideTxn
|
||||
}
|
||||
|
||||
ci := ©in{
|
||||
cn: cn,
|
||||
buffer: make([]byte, 0, ciBufferSize),
|
||||
rowData: make(chan []byte),
|
||||
done: make(chan bool, 1),
|
||||
}
|
||||
// add CopyData identifier + 4 bytes for message length
|
||||
ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
|
||||
|
||||
b := cn.writeBuf('Q')
|
||||
b.string(q)
|
||||
cn.send(b)
|
||||
|
||||
awaitCopyInResponse:
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'G':
|
||||
if r.byte() != 0 {
|
||||
err = errBinaryCopyNotSupported
|
||||
break awaitCopyInResponse
|
||||
}
|
||||
go ci.resploop()
|
||||
return ci, nil
|
||||
case 'H':
|
||||
err = errCopyToNotSupported
|
||||
break awaitCopyInResponse
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'Z':
|
||||
if err == nil {
|
||||
cn.bad = true
|
||||
errorf("unexpected ReadyForQuery in response to COPY")
|
||||
}
|
||||
cn.processReadyForQuery(r)
|
||||
return nil, err
|
||||
default:
|
||||
cn.bad = true
|
||||
errorf("unknown response for copy query: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
// something went wrong, abort COPY before we return
|
||||
b = cn.writeBuf('f')
|
||||
b.string(err.Error())
|
||||
cn.send(b)
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'c', 'C', 'E':
|
||||
case 'Z':
|
||||
// correctly aborted, we're done
|
||||
cn.processReadyForQuery(r)
|
||||
return nil, err
|
||||
default:
|
||||
cn.bad = true
|
||||
errorf("unknown response for CopyFail: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *copyin) flush(buf []byte) {
|
||||
// set message length (without message identifier)
|
||||
binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
|
||||
|
||||
_, err := ci.cn.c.Write(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *copyin) resploop() {
|
||||
for {
|
||||
var r readBuf
|
||||
t, err := ci.cn.recvMessage(&r)
|
||||
if err != nil {
|
||||
ci.cn.bad = true
|
||||
ci.setError(err)
|
||||
ci.done <- true
|
||||
return
|
||||
}
|
||||
switch t {
|
||||
case 'C':
|
||||
// complete
|
||||
case 'N':
|
||||
// NoticeResponse
|
||||
case 'Z':
|
||||
ci.cn.processReadyForQuery(&r)
|
||||
ci.done <- true
|
||||
return
|
||||
case 'E':
|
||||
err := parseError(&r)
|
||||
ci.setError(err)
|
||||
default:
|
||||
ci.cn.bad = true
|
||||
ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
|
||||
ci.done <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ci *copyin) isErrorSet() bool {
|
||||
ci.Lock()
|
||||
isSet := (ci.err != nil)
|
||||
ci.Unlock()
|
||||
return isSet
|
||||
}
|
||||
|
||||
// setError() sets ci.err if one has not been set already. Caller must not be
|
||||
// holding ci.Mutex.
|
||||
func (ci *copyin) setError(err error) {
|
||||
ci.Lock()
|
||||
if ci.err == nil {
|
||||
ci.err = err
|
||||
}
|
||||
ci.Unlock()
|
||||
}
|
||||
|
||||
func (ci *copyin) NumInput() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
|
||||
return nil, ErrNotSupported
|
||||
}
|
||||
|
||||
// Exec inserts values into the COPY stream. The insert is asynchronous
|
||||
// and Exec can return errors from previous Exec calls to the same
|
||||
// COPY stmt.
|
||||
//
|
||||
// You need to call Exec(nil) to sync the COPY stream and to get any
|
||||
// errors from pending data, since Stmt.Close() doesn't return errors
|
||||
// to the user.
|
||||
func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
|
||||
if ci.closed {
|
||||
return nil, errCopyInClosed
|
||||
}
|
||||
|
||||
if ci.cn.bad {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
defer ci.cn.errRecover(&err)
|
||||
|
||||
if ci.isErrorSet() {
|
||||
return nil, ci.err
|
||||
}
|
||||
|
||||
if len(v) == 0 {
|
||||
return nil, ci.Close()
|
||||
}
|
||||
|
||||
numValues := len(v)
|
||||
for i, value := range v {
|
||||
ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
|
||||
if i < numValues-1 {
|
||||
ci.buffer = append(ci.buffer, '\t')
|
||||
}
|
||||
}
|
||||
|
||||
ci.buffer = append(ci.buffer, '\n')
|
||||
|
||||
if len(ci.buffer) > ciBufferFlushSize {
|
||||
ci.flush(ci.buffer)
|
||||
// reset buffer, keep bytes for message identifier and length
|
||||
ci.buffer = ci.buffer[:5]
|
||||
}
|
||||
|
||||
return driver.RowsAffected(0), nil
|
||||
}
|
||||
|
||||
func (ci *copyin) Close() (err error) {
|
||||
if ci.closed { // Don't do anything, we're already closed
|
||||
return nil
|
||||
}
|
||||
ci.closed = true
|
||||
|
||||
if ci.cn.bad {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
defer ci.cn.errRecover(&err)
|
||||
|
||||
if len(ci.buffer) > 0 {
|
||||
ci.flush(ci.buffer)
|
||||
}
|
||||
// Avoid touching the scratch buffer as resploop could be using it.
|
||||
err = ci.cn.sendSimpleMessage('c')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-ci.done
|
||||
ci.cn.inCopy = false
|
||||
|
||||
if ci.isErrorSet() {
|
||||
err = ci.err
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,465 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCopyInStmt(t *testing.T) {
|
||||
var stmt string
|
||||
stmt = CopyIn("table name")
|
||||
if stmt != `COPY "table name" () FROM STDIN` {
|
||||
t.Fatal(stmt)
|
||||
}
|
||||
|
||||
stmt = CopyIn("table name", "column 1", "column 2")
|
||||
if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` {
|
||||
t.Fatal(stmt)
|
||||
}
|
||||
|
||||
stmt = CopyIn(`table " name """`, `co"lumn""`)
|
||||
if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` {
|
||||
t.Fatal(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyInSchemaStmt(t *testing.T) {
|
||||
var stmt string
|
||||
stmt = CopyInSchema("schema name", "table name")
|
||||
if stmt != `COPY "schema name"."table name" () FROM STDIN` {
|
||||
t.Fatal(stmt)
|
||||
}
|
||||
|
||||
stmt = CopyInSchema("schema name", "table name", "column 1", "column 2")
|
||||
if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` {
|
||||
t.Fatal(stmt)
|
||||
}
|
||||
|
||||
stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`)
|
||||
if stmt != `COPY "schema "" name """"""".`+
|
||||
`"table "" name """"""" ("co""lumn""""") FROM STDIN` {
|
||||
t.Fatal(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyInMultipleValues(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
longString := strings.Repeat("#", 500)
|
||||
|
||||
for i := 0; i < 500; i++ {
|
||||
_, err = stmt.Exec(int64(i), longString)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var num int
|
||||
err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if num != 500 {
|
||||
t.Fatalf("expected 500 items, not %d", num)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyInRaiseStmtTrigger(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
if getServerVersion(t, db) < 90000 {
|
||||
var exists int
|
||||
err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists)
|
||||
if err == sql.ErrNoRows {
|
||||
t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger")
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
CREATE OR REPLACE FUNCTION pg_temp.temptest()
|
||||
RETURNS trigger AS
|
||||
$BODY$ begin
|
||||
raise notice 'Hello world';
|
||||
return new;
|
||||
end $BODY$
|
||||
LANGUAGE plpgsql`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = txn.Exec(`
|
||||
CREATE TRIGGER temptest_trigger
|
||||
BEFORE INSERT
|
||||
ON temp
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE pg_temp.temptest()`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
longString := strings.Repeat("#", 500)
|
||||
|
||||
_, err = stmt.Exec(int64(1), longString)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var num int
|
||||
err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if num != 1 {
|
||||
t.Fatalf("expected 1 items, not %d", num)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyInTypes(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var num int
|
||||
var text string
|
||||
var blob []byte
|
||||
var nothing sql.NullString
|
||||
|
||||
err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if num != 1234567890 {
|
||||
t.Fatal("unexpected result", num)
|
||||
}
|
||||
if text != "Héllö\n ☃!\r\t\\" {
|
||||
t.Fatal("unexpected result", text)
|
||||
}
|
||||
if bytes.Compare(blob, []byte{0, 255, 9, 10, 13}) != 0 {
|
||||
t.Fatal("unexpected result", blob)
|
||||
}
|
||||
if nothing.Valid {
|
||||
t.Fatal("unexpected result", nothing.String)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyInWrongType(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(CopyIn("temp", "num"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec("Héllö\n ☃!\r\t\\")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" {
|
||||
t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyOutsideOfTxnError(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Prepare(CopyIn("temp", "num"))
|
||||
if err == nil {
|
||||
t.Fatal("COPY outside of transaction did not return an error")
|
||||
}
|
||||
if err != errCopyNotSupportedOutsideTxn {
|
||||
t.Fatalf("expected %s, got %s", err, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyInBinaryError(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary")
|
||||
if err != errBinaryCopyNotSupported {
|
||||
t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err)
|
||||
}
|
||||
// check that the protocol is in a valid state
|
||||
err = txn.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyFromError(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = txn.Prepare("COPY temp (num) TO STDOUT")
|
||||
if err != errCopyToNotSupported {
|
||||
t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err)
|
||||
}
|
||||
// check that the protocol is in a valid state
|
||||
err = txn.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopySyntaxError(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Prepare("COPY ")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if pge := err.(*Error); pge.Code.Name() != "syntax_error" {
|
||||
t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge)
|
||||
}
|
||||
// check that the protocol is in a valid state
|
||||
err = txn.Rollback()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests for connection errors in copyin.resploop()
|
||||
func TestCopyRespLoopConnectionError(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
var pid int
|
||||
err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (a int)")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(CopyIn("temp", "a"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = db.Exec("SELECT pg_terminate_backend($1)", pid)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if getServerVersion(t, db) < 90500 {
|
||||
// We have to try and send something over, since postgres before
|
||||
// version 9.5 won't process SIGTERMs while it's waiting for
|
||||
// CopyData/CopyEnd messages; see tcop/postgres.c.
|
||||
_, err = stmt.Exec(1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
_, err = stmt.Exec()
|
||||
if err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
pge, ok := err.(*Error)
|
||||
if !ok {
|
||||
if err == driver.ErrBadConn {
|
||||
// likely an EPIPE
|
||||
} else {
|
||||
t.Fatalf("expected *pq.Error or driver.ErrBadConn, got %+#v", err)
|
||||
}
|
||||
} else if pge.Code.Name() != "admin_shutdown" {
|
||||
t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name())
|
||||
}
|
||||
|
||||
_ = stmt.Close()
|
||||
}
|
||||
|
||||
func BenchmarkCopyIn(b *testing.B) {
|
||||
db := openTestConn(b)
|
||||
defer db.Close()
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err = stmt.Exec(int64(i), "hello world!")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
var num int
|
||||
err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if num != b.N {
|
||||
b.Fatalf("expected %d items, not %d", b.N, num)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,235 @@
|
|||
/*
|
||||
Package pq is a pure Go Postgres driver for the database/sql package.
|
||||
|
||||
In most cases clients will use the database/sql package instead of
|
||||
using this package directly. For example:
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
age := 21
|
||||
rows, err := db.Query("SELECT name FROM users WHERE age = $1", age)
|
||||
…
|
||||
}
|
||||
|
||||
You can also connect to a database using a URL. For example:
|
||||
|
||||
db, err := sql.Open("postgres", "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full")
|
||||
|
||||
|
||||
Connection String Parameters
|
||||
|
||||
|
||||
Similarly to libpq, when establishing a connection using pq you are expected to
|
||||
supply a connection string containing zero or more parameters.
|
||||
A subset of the connection parameters supported by libpq are also supported by pq.
|
||||
Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem)
|
||||
directly in the connection string. This is different from libpq, which does not allow
|
||||
run-time parameters in the connection string, instead requiring you to supply
|
||||
them in the options parameter.
|
||||
|
||||
For compatibility with libpq, the following special connection parameters are
|
||||
supported:
|
||||
|
||||
* dbname - The name of the database to connect to
|
||||
* user - The user to sign in as
|
||||
* password - The user's password
|
||||
* host - The host to connect to. Values that start with / are for unix domain sockets. (default is localhost)
|
||||
* port - The port to bind to. (default is 5432)
|
||||
* sslmode - Whether or not to use SSL (default is require, this is not the default for libpq)
|
||||
* fallback_application_name - An application_name to fall back to if one isn't provided.
|
||||
* connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely.
|
||||
* sslcert - Cert file location. The file must contain PEM encoded data.
|
||||
* sslkey - Key file location. The file must contain PEM encoded data.
|
||||
* sslrootcert - The location of the root certificate file. The file must contain PEM encoded data.
|
||||
|
||||
Valid values for sslmode are:
|
||||
|
||||
* disable - No SSL
|
||||
* require - Always SSL (skip verification)
|
||||
* verify-ca - Always SSL (verify that the certificate presented by the server was signed by a trusted CA)
|
||||
* verify-full - Always SSL (verify that the certification presented by the server was signed by a trusted CA and the server host name matches the one in the certificate)
|
||||
|
||||
See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
|
||||
for more information about connection string parameters.
|
||||
|
||||
Use single quotes for values that contain whitespace:
|
||||
|
||||
"user=pqgotest password='with spaces'"
|
||||
|
||||
A backslash will escape the next character in values:
|
||||
|
||||
"user=space\ man password='it\'s valid'
|
||||
|
||||
Note that the connection parameter client_encoding (which sets the
|
||||
text encoding for the connection) may be set but must be "UTF8",
|
||||
matching with the same rules as Postgres. It is an error to provide
|
||||
any other value.
|
||||
|
||||
In addition to the parameters listed above, any run-time parameter that can be
|
||||
set at backend start time can be set in the connection string. For more
|
||||
information, see
|
||||
http://www.postgresql.org/docs/current/static/runtime-config.html.
|
||||
|
||||
Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html
|
||||
supported by libpq are also supported by pq. If any of the environment
|
||||
variables not supported by pq are set, pq will panic during connection
|
||||
establishment. Environment variables have a lower precedence than explicitly
|
||||
provided connection parameters.
|
||||
|
||||
The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html
|
||||
is supported, but on Windows PGPASSFILE must be specified explicitly.
|
||||
|
||||
|
||||
Queries
|
||||
|
||||
|
||||
database/sql does not dictate any specific format for parameter
|
||||
markers in query strings, and pq uses the Postgres-native ordinal markers,
|
||||
as shown above. The same marker can be reused for the same parameter:
|
||||
|
||||
rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1
|
||||
OR age BETWEEN $2 AND $2 + 3`, "orange", 64)
|
||||
|
||||
pq does not support the LastInsertId() method of the Result type in database/sql.
|
||||
To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres
|
||||
RETURNING clause with a standard Query or QueryRow call:
|
||||
|
||||
var userid int
|
||||
err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age)
|
||||
VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid)
|
||||
|
||||
For more details on RETURNING, see the Postgres documentation:
|
||||
|
||||
http://www.postgresql.org/docs/current/static/sql-insert.html
|
||||
http://www.postgresql.org/docs/current/static/sql-update.html
|
||||
http://www.postgresql.org/docs/current/static/sql-delete.html
|
||||
|
||||
For additional instructions on querying see the documentation for the database/sql package.
|
||||
|
||||
|
||||
Data Types
|
||||
|
||||
|
||||
Parameters pass through driver.DefaultParameterConverter before they are handled
|
||||
by this package. When the binary_parameters connection option is enabled,
|
||||
[]byte values are sent directly to the backend as data in binary format.
|
||||
|
||||
This package returns the following types for values from the PostgreSQL backend:
|
||||
|
||||
- integer types smallint, integer, and bigint are returned as int64
|
||||
- floating-point types real and double precision are returned as float64
|
||||
- character types char, varchar, and text are returned as string
|
||||
- temporal types date, time, timetz, timestamp, and timestamptz are returned as time.Time
|
||||
- the boolean type is returned as bool
|
||||
- the bytea type is returned as []byte
|
||||
|
||||
All other types are returned directly from the backend as []byte values in text format.
|
||||
|
||||
|
||||
Errors
|
||||
|
||||
|
||||
pq may return errors of type *pq.Error which can be interrogated for error details:
|
||||
|
||||
if err, ok := err.(*pq.Error); ok {
|
||||
fmt.Println("pq error:", err.Code.Name())
|
||||
}
|
||||
|
||||
See the pq.Error type for details.
|
||||
|
||||
|
||||
Bulk imports
|
||||
|
||||
You can perform bulk imports by preparing a statement returned by pq.CopyIn (or
|
||||
pq.CopyInSchema) in an explicit transaction (sql.Tx). The returned statement
|
||||
handle can then be repeatedly "executed" to copy data into the target table.
|
||||
After all data has been processed you should call Exec() once with no arguments
|
||||
to flush all buffered data. Any call to Exec() might return an error which
|
||||
should be handled appropriately, but because of the internal buffering an error
|
||||
returned by Exec() might not be related to the data passed in the call that
|
||||
failed.
|
||||
|
||||
CopyIn uses COPY FROM internally. It is not possible to COPY outside of an
|
||||
explicit transaction in pq.
|
||||
|
||||
Usage example:
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age"))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
_, err = stmt.Exec(user.Name, int64(user.Age))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = txn.Commit()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
Notifications
|
||||
|
||||
|
||||
PostgreSQL supports a simple publish/subscribe model over database
|
||||
connections. See http://www.postgresql.org/docs/current/static/sql-notify.html
|
||||
for more information about the general mechanism.
|
||||
|
||||
To start listening for notifications, you first have to open a new connection
|
||||
to the database by calling NewListener. This connection can not be used for
|
||||
anything other than LISTEN / NOTIFY. Calling Listen will open a "notification
|
||||
channel"; once a notification channel is open, a notification generated on that
|
||||
channel will effect a send on the Listener.Notify channel. A notification
|
||||
channel will remain open until Unlisten is called, though connection loss might
|
||||
result in some notifications being lost. To solve this problem, Listener sends
|
||||
a nil pointer over the Notify channel any time the connection is re-established
|
||||
following a connection loss. The application can get information about the
|
||||
state of the underlying connection by setting an event callback in the call to
|
||||
NewListener.
|
||||
|
||||
A single Listener can safely be used from concurrent goroutines, which means
|
||||
that there is often no need to create more than one Listener in your
|
||||
application. However, a Listener is always connected to a single database, so
|
||||
you will need to create a new Listener instance for every database you want to
|
||||
receive notifications in.
|
||||
|
||||
The channel name in both Listen and Unlisten is case sensitive, and can contain
|
||||
any characters legal in an identifier (see
|
||||
http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
|
||||
for more information). Note that the channel name will be truncated to 63
|
||||
bytes by the PostgreSQL server.
|
||||
|
||||
You can find a complete, working example of Listener usage at
|
||||
http://godoc.org/github.com/lib/pq/listen_example.
|
||||
|
||||
*/
|
||||
package pq
|
||||
|
|
@ -0,0 +1,589 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
|
||||
switch v := x.(type) {
|
||||
case []byte:
|
||||
return v
|
||||
default:
|
||||
return encode(parameterStatus, x, oid.T_unknown)
|
||||
}
|
||||
}
|
||||
|
||||
func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return strconv.AppendInt(nil, v, 10)
|
||||
case float64:
|
||||
return strconv.AppendFloat(nil, v, 'f', -1, 64)
|
||||
case []byte:
|
||||
if pgtypOid == oid.T_bytea {
|
||||
return encodeBytea(parameterStatus.serverVersion, v)
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
if pgtypOid == oid.T_bytea {
|
||||
return encodeBytea(parameterStatus.serverVersion, []byte(v))
|
||||
}
|
||||
|
||||
return []byte(v)
|
||||
case bool:
|
||||
return strconv.AppendBool(nil, v)
|
||||
case time.Time:
|
||||
return formatTs(v)
|
||||
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
|
||||
switch f {
|
||||
case formatBinary:
|
||||
return binaryDecode(parameterStatus, s, typ)
|
||||
case formatText:
|
||||
return textDecode(parameterStatus, s, typ)
|
||||
default:
|
||||
panic("not reached")
|
||||
}
|
||||
}
|
||||
|
||||
func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
|
||||
switch typ {
|
||||
case oid.T_bytea:
|
||||
return s
|
||||
case oid.T_int8:
|
||||
return int64(binary.BigEndian.Uint64(s))
|
||||
case oid.T_int4:
|
||||
return int64(int32(binary.BigEndian.Uint32(s)))
|
||||
case oid.T_int2:
|
||||
return int64(int16(binary.BigEndian.Uint16(s)))
|
||||
|
||||
default:
|
||||
errorf("don't know how to decode binary parameter of type %d", uint32(typ))
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
|
||||
switch typ {
|
||||
case oid.T_char, oid.T_varchar, oid.T_text:
|
||||
return string(s)
|
||||
case oid.T_bytea:
|
||||
b, err := parseBytea(s)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return b
|
||||
case oid.T_timestamptz:
|
||||
return parseTs(parameterStatus.currentLocation, string(s))
|
||||
case oid.T_timestamp, oid.T_date:
|
||||
return parseTs(nil, string(s))
|
||||
case oid.T_time:
|
||||
return mustParse("15:04:05", typ, s)
|
||||
case oid.T_timetz:
|
||||
return mustParse("15:04:05-07", typ, s)
|
||||
case oid.T_bool:
|
||||
return s[0] == 't'
|
||||
case oid.T_int8, oid.T_int4, oid.T_int2:
|
||||
i, err := strconv.ParseInt(string(s), 10, 64)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return i
|
||||
case oid.T_float4, oid.T_float8:
|
||||
bits := 64
|
||||
if typ == oid.T_float4 {
|
||||
bits = 32
|
||||
}
|
||||
f, err := strconv.ParseFloat(string(s), bits)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// appendEncodedText encodes item in text format as required by COPY
|
||||
// and appends to buf
|
||||
func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return strconv.AppendInt(buf, v, 10)
|
||||
case float64:
|
||||
return strconv.AppendFloat(buf, v, 'f', -1, 64)
|
||||
case []byte:
|
||||
encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
|
||||
return appendEscapedText(buf, string(encodedBytea))
|
||||
case string:
|
||||
return appendEscapedText(buf, v)
|
||||
case bool:
|
||||
return strconv.AppendBool(buf, v)
|
||||
case time.Time:
|
||||
return append(buf, formatTs(v)...)
|
||||
case nil:
|
||||
return append(buf, "\\N"...)
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func appendEscapedText(buf []byte, text string) []byte {
|
||||
escapeNeeded := false
|
||||
startPos := 0
|
||||
var c byte
|
||||
|
||||
// check if we need to escape
|
||||
for i := 0; i < len(text); i++ {
|
||||
c = text[i]
|
||||
if c == '\\' || c == '\n' || c == '\r' || c == '\t' {
|
||||
escapeNeeded = true
|
||||
startPos = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if !escapeNeeded {
|
||||
return append(buf, text...)
|
||||
}
|
||||
|
||||
// copy till first char to escape, iterate the rest
|
||||
result := append(buf, text[:startPos]...)
|
||||
for i := startPos; i < len(text); i++ {
|
||||
c = text[i]
|
||||
switch c {
|
||||
case '\\':
|
||||
result = append(result, '\\', '\\')
|
||||
case '\n':
|
||||
result = append(result, '\\', 'n')
|
||||
case '\r':
|
||||
result = append(result, '\\', 'r')
|
||||
case '\t':
|
||||
result = append(result, '\\', 't')
|
||||
default:
|
||||
result = append(result, c)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func mustParse(f string, typ oid.Oid, s []byte) time.Time {
|
||||
str := string(s)
|
||||
|
||||
// check for a 30-minute-offset timezone
|
||||
if (typ == oid.T_timestamptz || typ == oid.T_timetz) &&
|
||||
str[len(str)-3] == ':' {
|
||||
f += ":00"
|
||||
}
|
||||
t, err := time.Parse(f, str)
|
||||
if err != nil {
|
||||
errorf("decode: %s", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
var errInvalidTimestamp = errors.New("invalid timestamp")
|
||||
|
||||
type timestampParser struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (p *timestampParser) expect(str string, char byte, pos int) {
|
||||
if p.err != nil {
|
||||
return
|
||||
}
|
||||
if pos+1 > len(str) {
|
||||
p.err = errInvalidTimestamp
|
||||
return
|
||||
}
|
||||
if c := str[pos]; c != char && p.err == nil {
|
||||
p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
|
||||
if p.err != nil {
|
||||
return 0
|
||||
}
|
||||
if begin < 0 || end < 0 || begin > end || end > len(str) {
|
||||
p.err = errInvalidTimestamp
|
||||
return 0
|
||||
}
|
||||
result, err := strconv.Atoi(str[begin:end])
|
||||
if err != nil {
|
||||
if p.err == nil {
|
||||
p.err = fmt.Errorf("expected number; got '%v'", str)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// The location cache caches the time zones typically used by the client.
|
||||
type locationCache struct {
|
||||
cache map[int]*time.Location
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// All connections share the same list of timezones. Benchmarking shows that
|
||||
// about 5% speed could be gained by putting the cache in the connection and
|
||||
// losing the mutex, at the cost of a small amount of memory and a somewhat
|
||||
// significant increase in code complexity.
|
||||
var globalLocationCache = newLocationCache()
|
||||
|
||||
func newLocationCache() *locationCache {
|
||||
return &locationCache{cache: make(map[int]*time.Location)}
|
||||
}
|
||||
|
||||
// Returns the cached timezone for the specified offset, creating and caching
|
||||
// it if necessary.
|
||||
func (c *locationCache) getLocation(offset int) *time.Location {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
location, ok := c.cache[offset]
|
||||
if !ok {
|
||||
location = time.FixedZone("", offset)
|
||||
c.cache[offset] = location
|
||||
}
|
||||
|
||||
return location
|
||||
}
|
||||
|
||||
var infinityTsEnabled = false
|
||||
var infinityTsNegative time.Time
|
||||
var infinityTsPositive time.Time
|
||||
|
||||
const (
|
||||
infinityTsEnabledAlready = "pq: infinity timestamp enabled already"
|
||||
infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
|
||||
)
|
||||
|
||||
// EnableInfinityTs controls the handling of Postgres' "-infinity" and
|
||||
// "infinity" "timestamp"s.
|
||||
//
|
||||
// If EnableInfinityTs is not called, "-infinity" and "infinity" will return
|
||||
// []byte("-infinity") and []byte("infinity") respectively, and potentially
|
||||
// cause error "sql: Scan error on column index 0: unsupported driver -> Scan
|
||||
// pair: []uint8 -> *time.Time", when scanning into a time.Time value.
|
||||
//
|
||||
// Once EnableInfinityTs has been called, all connections created using this
|
||||
// driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
|
||||
// "timestamp with time zone" and "date" types to the predefined minimum and
|
||||
// maximum times, respectively. When encoding time.Time values, any time which
|
||||
// equals or precedes the predefined minimum time will be encoded to
|
||||
// "-infinity". Any values at or past the maximum time will similarly be
|
||||
// encoded to "infinity".
|
||||
//
|
||||
// If EnableInfinityTs is called with negative >= positive, it will panic.
|
||||
// Calling EnableInfinityTs after a connection has been established results in
|
||||
// undefined behavior. If EnableInfinityTs is called more than once, it will
|
||||
// panic.
|
||||
func EnableInfinityTs(negative time.Time, positive time.Time) {
|
||||
if infinityTsEnabled {
|
||||
panic(infinityTsEnabledAlready)
|
||||
}
|
||||
if !negative.Before(positive) {
|
||||
panic(infinityTsNegativeMustBeSmaller)
|
||||
}
|
||||
infinityTsEnabled = true
|
||||
infinityTsNegative = negative
|
||||
infinityTsPositive = positive
|
||||
}
|
||||
|
||||
/*
|
||||
* Testing might want to toggle infinityTsEnabled
|
||||
*/
|
||||
func disableInfinityTs() {
|
||||
infinityTsEnabled = false
|
||||
}
|
||||
|
||||
// This is a time function specific to the Postgres default DateStyle
|
||||
// setting ("ISO, MDY"), the only one we currently support. This
|
||||
// accounts for the discrepancies between the parsing available with
|
||||
// time.Parse and the Postgres date formatting quirks.
|
||||
func parseTs(currentLocation *time.Location, str string) interface{} {
|
||||
switch str {
|
||||
case "-infinity":
|
||||
if infinityTsEnabled {
|
||||
return infinityTsNegative
|
||||
}
|
||||
return []byte(str)
|
||||
case "infinity":
|
||||
if infinityTsEnabled {
|
||||
return infinityTsPositive
|
||||
}
|
||||
return []byte(str)
|
||||
}
|
||||
t, err := ParseTimestamp(currentLocation, str)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// ParseTimestamp parses Postgres' text format. It returns a time.Time in
|
||||
// currentLocation iff that time's offset agrees with the offset sent from the
|
||||
// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
|
||||
// fixed offset offset provided by the Postgres server.
|
||||
func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
|
||||
p := timestampParser{}
|
||||
|
||||
monSep := strings.IndexRune(str, '-')
|
||||
// this is Gregorian year, not ISO Year
|
||||
// In Gregorian system, the year 1 BC is followed by AD 1
|
||||
year := p.mustAtoi(str, 0, monSep)
|
||||
daySep := monSep + 3
|
||||
month := p.mustAtoi(str, monSep+1, daySep)
|
||||
p.expect(str, '-', daySep)
|
||||
timeSep := daySep + 3
|
||||
day := p.mustAtoi(str, daySep+1, timeSep)
|
||||
|
||||
var hour, minute, second int
|
||||
if len(str) > monSep+len("01-01")+1 {
|
||||
p.expect(str, ' ', timeSep)
|
||||
minSep := timeSep + 3
|
||||
p.expect(str, ':', minSep)
|
||||
hour = p.mustAtoi(str, timeSep+1, minSep)
|
||||
secSep := minSep + 3
|
||||
p.expect(str, ':', secSep)
|
||||
minute = p.mustAtoi(str, minSep+1, secSep)
|
||||
secEnd := secSep + 3
|
||||
second = p.mustAtoi(str, secSep+1, secEnd)
|
||||
}
|
||||
remainderIdx := monSep + len("01-01 00:00:00") + 1
|
||||
// Three optional (but ordered) sections follow: the
|
||||
// fractional seconds, the time zone offset, and the BC
|
||||
// designation. We set them up here and adjust the other
|
||||
// offsets if the preceding sections exist.
|
||||
|
||||
nanoSec := 0
|
||||
tzOff := 0
|
||||
|
||||
if remainderIdx < len(str) && str[remainderIdx] == '.' {
|
||||
fracStart := remainderIdx + 1
|
||||
fracOff := strings.IndexAny(str[fracStart:], "-+ ")
|
||||
if fracOff < 0 {
|
||||
fracOff = len(str) - fracStart
|
||||
}
|
||||
fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
|
||||
nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
|
||||
|
||||
remainderIdx += fracOff + 1
|
||||
}
|
||||
if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
|
||||
// time zone separator is always '-' or '+' (UTC is +00)
|
||||
var tzSign int
|
||||
switch c := str[tzStart]; c {
|
||||
case '-':
|
||||
tzSign = -1
|
||||
case '+':
|
||||
tzSign = +1
|
||||
default:
|
||||
return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
|
||||
}
|
||||
tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
|
||||
remainderIdx += 3
|
||||
var tzMin, tzSec int
|
||||
if remainderIdx < len(str) && str[remainderIdx] == ':' {
|
||||
tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
|
||||
remainderIdx += 3
|
||||
}
|
||||
if remainderIdx < len(str) && str[remainderIdx] == ':' {
|
||||
tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
|
||||
remainderIdx += 3
|
||||
}
|
||||
tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
|
||||
}
|
||||
var isoYear int
|
||||
if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" {
|
||||
isoYear = 1 - year
|
||||
remainderIdx += 3
|
||||
} else {
|
||||
isoYear = year
|
||||
}
|
||||
if remainderIdx < len(str) {
|
||||
return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
|
||||
}
|
||||
t := time.Date(isoYear, time.Month(month), day,
|
||||
hour, minute, second, nanoSec,
|
||||
globalLocationCache.getLocation(tzOff))
|
||||
|
||||
if currentLocation != nil {
|
||||
// Set the location of the returned Time based on the session's
|
||||
// TimeZone value, but only if the local time zone database agrees with
|
||||
// the remote database on the offset.
|
||||
lt := t.In(currentLocation)
|
||||
_, newOff := lt.Zone()
|
||||
if newOff == tzOff {
|
||||
t = lt
|
||||
}
|
||||
}
|
||||
|
||||
return t, p.err
|
||||
}
|
||||
|
||||
// formatTs formats t into a format postgres understands.
|
||||
func formatTs(t time.Time) []byte {
|
||||
if infinityTsEnabled {
|
||||
// t <= -infinity : ! (t > -infinity)
|
||||
if !t.After(infinityTsNegative) {
|
||||
return []byte("-infinity")
|
||||
}
|
||||
// t >= infinity : ! (!t < infinity)
|
||||
if !t.Before(infinityTsPositive) {
|
||||
return []byte("infinity")
|
||||
}
|
||||
}
|
||||
return FormatTimestamp(t)
|
||||
}
|
||||
|
||||
// FormatTimestamp formats t into Postgres' text format for timestamps.
|
||||
func FormatTimestamp(t time.Time) []byte {
|
||||
// Need to send dates before 0001 A.D. with " BC" suffix, instead of the
|
||||
// minus sign preferred by Go.
|
||||
// Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
|
||||
bc := false
|
||||
if t.Year() <= 0 {
|
||||
// flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
|
||||
t = t.AddDate((-t.Year())*2+1, 0, 0)
|
||||
bc = true
|
||||
}
|
||||
b := []byte(t.Format(time.RFC3339Nano))
|
||||
|
||||
_, offset := t.Zone()
|
||||
offset = offset % 60
|
||||
if offset != 0 {
|
||||
// RFC3339Nano already printed the minus sign
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
b = append(b, ':')
|
||||
if offset < 10 {
|
||||
b = append(b, '0')
|
||||
}
|
||||
b = strconv.AppendInt(b, int64(offset), 10)
|
||||
}
|
||||
|
||||
if bc {
|
||||
b = append(b, " BC"...)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Parse a bytea value received from the server. Both "hex" and the legacy
|
||||
// "escape" format are supported.
|
||||
func parseBytea(s []byte) (result []byte, err error) {
|
||||
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
|
||||
// bytea_output = hex
|
||||
s = s[2:] // trim off leading "\\x"
|
||||
result = make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(result, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// bytea_output = escape
|
||||
for len(s) > 0 {
|
||||
if s[0] == '\\' {
|
||||
// escaped '\\'
|
||||
if len(s) >= 2 && s[1] == '\\' {
|
||||
result = append(result, '\\')
|
||||
s = s[2:]
|
||||
continue
|
||||
}
|
||||
|
||||
// '\\' followed by an octal number
|
||||
if len(s) < 4 {
|
||||
return nil, fmt.Errorf("invalid bytea sequence %v", s)
|
||||
}
|
||||
r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
|
||||
}
|
||||
result = append(result, byte(r))
|
||||
s = s[4:]
|
||||
} else {
|
||||
// We hit an unescaped, raw byte. Try to read in as many as
|
||||
// possible in one go.
|
||||
i := bytes.IndexByte(s, '\\')
|
||||
if i == -1 {
|
||||
result = append(result, s...)
|
||||
break
|
||||
}
|
||||
result = append(result, s[:i]...)
|
||||
s = s[i:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func encodeBytea(serverVersion int, v []byte) (result []byte) {
|
||||
if serverVersion >= 90000 {
|
||||
// Use the hex format if we know that the server supports it
|
||||
result = make([]byte, 2+hex.EncodedLen(len(v)))
|
||||
result[0] = '\\'
|
||||
result[1] = 'x'
|
||||
hex.Encode(result[2:], v)
|
||||
} else {
|
||||
// .. or resort to "escape"
|
||||
for _, b := range v {
|
||||
if b == '\\' {
|
||||
result = append(result, '\\', '\\')
|
||||
} else if b < 0x20 || b > 0x7e {
|
||||
result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
|
||||
} else {
|
||||
result = append(result, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// NullTime represents a time.Time that may be null. NullTime implements the
|
||||
// sql.Scanner interface so it can be used as a scan destination, similar to
|
||||
// sql.NullString.
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
nt.Time, nt.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,738 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq/oid"
|
||||
)
|
||||
|
||||
func TestScanTimestamp(t *testing.T) {
|
||||
var nt NullTime
|
||||
tn := time.Now()
|
||||
nt.Scan(tn)
|
||||
if !nt.Valid {
|
||||
t.Errorf("Expected Valid=false")
|
||||
}
|
||||
if nt.Time != tn {
|
||||
t.Errorf("Time value mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanNilTimestamp(t *testing.T) {
|
||||
var nt NullTime
|
||||
nt.Scan(nil)
|
||||
if nt.Valid {
|
||||
t.Errorf("Expected Valid=false")
|
||||
}
|
||||
}
|
||||
|
||||
var timeTests = []struct {
|
||||
str string
|
||||
timeval time.Time
|
||||
}{
|
||||
{"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.FixedZone("", 0))},
|
||||
{"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000,
|
||||
time.FixedZone("", -7*60*60))},
|
||||
{"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", -7*60*60))},
|
||||
{"2001-02-03 04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", -(7*60*60+42*60)))},
|
||||
{"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", -(7*60*60+30*60+9)))},
|
||||
{"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0,
|
||||
time.FixedZone("", 7*60*60))},
|
||||
{"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))},
|
||||
{"0011-02-03 04:05:06.123 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"0011-02-03 04:05:06.123-07 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000,
|
||||
time.FixedZone("", -7*60*60))},
|
||||
{"0001-02-03 04:05:06.123", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"0001-02-03 04:05:06.123 BC", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)},
|
||||
{"0001-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"0002-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)},
|
||||
{"0002-02-03 04:05:06.123 BC", time.Date(-1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))},
|
||||
{"12345-02-03 04:05:06.1", time.Date(12345, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))},
|
||||
{"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))},
|
||||
}
|
||||
|
||||
// Test that parsing the string results in the expected value.
|
||||
func TestParseTs(t *testing.T) {
|
||||
for i, tt := range timeTests {
|
||||
val, err := ParseTimestamp(nil, tt.str)
|
||||
if err != nil {
|
||||
t.Errorf("%d: got error: %v", i, err)
|
||||
} else if val.String() != tt.timeval.String() {
|
||||
t.Errorf("%d: expected to parse %q into %q; got %q",
|
||||
i, tt.str, tt.timeval, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var timeErrorTests = []string{
|
||||
"2001",
|
||||
"2001-2-03",
|
||||
"2001-02-3",
|
||||
"2001-02-03 ",
|
||||
"2001-02-03 04",
|
||||
"2001-02-03 04:",
|
||||
"2001-02-03 04:05",
|
||||
"2001-02-03 04:05:",
|
||||
"2001-02-03 04:05:6",
|
||||
"2001-02-03 04:05:06.123 B",
|
||||
}
|
||||
|
||||
// Test that parsing the string results in an error.
|
||||
func TestParseTsErrors(t *testing.T) {
|
||||
for i, tt := range timeErrorTests {
|
||||
_, err := ParseTimestamp(nil, tt)
|
||||
if err == nil {
|
||||
t.Errorf("%d: expected an error from parsing: %v", i, tt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now test that sending the value into the database and parsing it back
|
||||
// returns the same time.Time value.
|
||||
func TestEncodeAndParseTs(t *testing.T) {
|
||||
db, err := openTestConnConninfo("timezone='Etc/UTC'")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
for i, tt := range timeTests {
|
||||
var dbstr string
|
||||
err = db.QueryRow("SELECT ($1::timestamptz)::text", tt.timeval).Scan(&dbstr)
|
||||
if err != nil {
|
||||
t.Errorf("%d: could not send value %q to the database: %s", i, tt.timeval, err)
|
||||
continue
|
||||
}
|
||||
|
||||
val, err := ParseTimestamp(nil, dbstr)
|
||||
if err != nil {
|
||||
t.Errorf("%d: could not parse value %q: %s", i, dbstr, err)
|
||||
continue
|
||||
}
|
||||
val = val.In(tt.timeval.Location())
|
||||
if val.String() != tt.timeval.String() {
|
||||
t.Errorf("%d: expected to parse %q into %q; got %q", i, dbstr, tt.timeval, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatTimeTests = []struct {
|
||||
time time.Time
|
||||
expected string
|
||||
}{
|
||||
{time.Time{}, "0001-01-01T00:00:00Z"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03T04:05:06.123456789Z"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03T04:05:06.123456789+02:00"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03T04:05:06.123456789-06:00"},
|
||||
{time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03T04:05:06-07:30:09"},
|
||||
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03T04:05:06.123456789Z"},
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03T04:05:06.123456789+02:00"},
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03T04:05:06.123456789-06:00"},
|
||||
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03T04:05:06.123456789Z BC"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03T04:05:06.123456789+02:00 BC"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03T04:05:06.123456789-06:00 BC"},
|
||||
|
||||
{time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03T04:05:06-07:30:09"},
|
||||
{time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03T04:05:06-07:30:09 BC"},
|
||||
}
|
||||
|
||||
func TestFormatTs(t *testing.T) {
|
||||
for i, tt := range formatTimeTests {
|
||||
val := string(formatTs(tt.time))
|
||||
if val != tt.expected {
|
||||
t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampWithTimeZone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// try several different locations, all included in Go's zoneinfo.zip
|
||||
for _, locName := range []string{
|
||||
"UTC",
|
||||
"America/Chicago",
|
||||
"America/New_York",
|
||||
"Australia/Darwin",
|
||||
"Australia/Perth",
|
||||
} {
|
||||
loc, err := time.LoadLocation(locName)
|
||||
if err != nil {
|
||||
t.Logf("Could not load time zone %s - skipping", locName)
|
||||
continue
|
||||
}
|
||||
|
||||
// Postgres timestamps have a resolution of 1 microsecond, so don't
|
||||
// use the full range of the Nanosecond argument
|
||||
refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc)
|
||||
|
||||
for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} {
|
||||
// Switch Postgres's timezone to test different output timestamp formats
|
||||
_, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var gotTime time.Time
|
||||
row := tx.QueryRow("select $1::timestamp with time zone", refTime)
|
||||
err = row.Scan(&gotTime)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !refTime.Equal(gotTime) {
|
||||
t.Errorf("timestamps not equal: %s != %s", refTime, gotTime)
|
||||
}
|
||||
|
||||
// check that the time zone is set correctly based on TimeZone
|
||||
pgLoc, err := time.LoadLocation(pgTimeZone)
|
||||
if err != nil {
|
||||
t.Logf("Could not load time zone %s - skipping", pgLoc)
|
||||
continue
|
||||
}
|
||||
translated := refTime.In(pgLoc)
|
||||
if translated.String() != gotTime.String() {
|
||||
t.Errorf("timestamps not equal: %s != %s", translated, gotTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampWithOutTimezone(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
test := func(ts, pgts string) {
|
||||
r, err := db.Query("SELECT $1::timestamp", pgts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not run query: %v", err)
|
||||
}
|
||||
|
||||
n := r.Next()
|
||||
|
||||
if n != true {
|
||||
t.Fatal("Expected at least one row")
|
||||
}
|
||||
|
||||
var result time.Time
|
||||
err = r.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatalf("Did not expect error scanning row: %v", err)
|
||||
}
|
||||
|
||||
expected, err := time.Parse(time.RFC3339, ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse test time literal: %v", err)
|
||||
}
|
||||
|
||||
if !result.Equal(expected) {
|
||||
t.Fatalf("Expected time to match %v: got mismatch %v",
|
||||
expected, result)
|
||||
}
|
||||
|
||||
n = r.Next()
|
||||
if n != false {
|
||||
t.Fatal("Expected only one row")
|
||||
}
|
||||
}
|
||||
|
||||
test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00")
|
||||
|
||||
// Test higher precision time
|
||||
test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033")
|
||||
}
|
||||
|
||||
func TestInfinityTimestamp(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
var err error
|
||||
var resultT time.Time
|
||||
|
||||
expectedErrorStrPrefix := `sql: Scan error on column index 0: unsupported`
|
||||
type testCases []struct {
|
||||
Query string
|
||||
Param string
|
||||
ExpectedErrStrPrefix string
|
||||
ExpectedVal interface{}
|
||||
}
|
||||
tc := testCases{
|
||||
{"SELECT $1::timestamp", "-infinity", expectedErrorStrPrefix, "-infinity"},
|
||||
{"SELECT $1::timestamptz", "-infinity", expectedErrorStrPrefix, "-infinity"},
|
||||
{"SELECT $1::timestamp", "infinity", expectedErrorStrPrefix, "infinity"},
|
||||
{"SELECT $1::timestamptz", "infinity", expectedErrorStrPrefix, "infinity"},
|
||||
}
|
||||
// try to assert []byte to time.Time
|
||||
for _, q := range tc {
|
||||
err = db.QueryRow(q.Query, q.Param).Scan(&resultT)
|
||||
if !strings.HasPrefix(err.Error(), q.ExpectedErrStrPrefix) {
|
||||
t.Errorf("Scanning -/+infinity, expected error to have prefix %q, got %q", q.ExpectedErrStrPrefix, err)
|
||||
}
|
||||
}
|
||||
// yield []byte
|
||||
for _, q := range tc {
|
||||
var resultI interface{}
|
||||
err = db.QueryRow(q.Query, q.Param).Scan(&resultI)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning -/+infinity, expected no error, got %q", err)
|
||||
}
|
||||
result, ok := resultI.([]byte)
|
||||
if !ok {
|
||||
t.Errorf("Scanning -/+infinity, expected []byte, got %#v", resultI)
|
||||
}
|
||||
if string(result) != q.ExpectedVal {
|
||||
t.Errorf("Scanning -/+infinity, expected %q, got %q", q.ExpectedVal, result)
|
||||
}
|
||||
}
|
||||
|
||||
y1500 := time.Date(1500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
y2500 := time.Date(2500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
EnableInfinityTs(y1500, y2500)
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamp", "infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y2500) {
|
||||
t.Errorf("Scanning infinity, expected %q, got %q", y2500, resultT)
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamptz", "infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y2500) {
|
||||
t.Errorf("Scanning Infinity, expected time %q, got %q", y2500, resultT.String())
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamp", "-infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y1500) {
|
||||
t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String())
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamptz", "-infinity").Scan(&resultT)
|
||||
if err != nil {
|
||||
t.Errorf("Scanning -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if !resultT.Equal(y1500) {
|
||||
t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String())
|
||||
}
|
||||
|
||||
y_1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
var s string
|
||||
err = db.QueryRow("SELECT $1::timestamp::text", y_1500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "-infinity" {
|
||||
t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s)
|
||||
}
|
||||
err = db.QueryRow("SELECT $1::timestamptz::text", y_1500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding -infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "-infinity" {
|
||||
t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s)
|
||||
}
|
||||
|
||||
err = db.QueryRow("SELECT $1::timestamp::text", y11500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "infinity" {
|
||||
t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s)
|
||||
}
|
||||
err = db.QueryRow("SELECT $1::timestamptz::text", y11500).Scan(&s)
|
||||
if err != nil {
|
||||
t.Errorf("Encoding infinity, expected no error, got %q", err)
|
||||
}
|
||||
if s != "infinity" {
|
||||
t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s)
|
||||
}
|
||||
|
||||
disableInfinityTs()
|
||||
|
||||
var panicErrorString string
|
||||
func() {
|
||||
defer func() {
|
||||
panicErrorString, _ = recover().(string)
|
||||
}()
|
||||
EnableInfinityTs(y2500, y1500)
|
||||
}()
|
||||
if panicErrorString != infinityTsNegativeMustBeSmaller {
|
||||
t.Errorf("Expected error, %q, got %q", infinityTsNegativeMustBeSmaller, panicErrorString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringWithNul(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
hello0world := string("hello\x00world")
|
||||
_, err := db.Query("SELECT $1::text", &hello0world)
|
||||
if err == nil {
|
||||
t.Fatal("Postgres accepts a string with nul in it; " +
|
||||
"injection attacks may be plausible")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteSliceToText(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte("hello world")
|
||||
row := db.QueryRow("SELECT $1::text", b)
|
||||
|
||||
var result []byte
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(result) != string(b) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToBytea(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := "hello world"
|
||||
row := db.QueryRow("SELECT $1::bytea", b)
|
||||
|
||||
var result []byte
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(result, []byte(b)) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextByteSliceToUUID(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
|
||||
row := db.QueryRow("SELECT $1::uuid", b)
|
||||
|
||||
var result string
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22P03" {
|
||||
t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != string(b) {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBinaryByteSlicetoUUID(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
b := []byte{'\xa0', '\xee', '\xbc', '\x99',
|
||||
'\x9c', '\x0b',
|
||||
'\x4e', '\xf8',
|
||||
'\xbb', '\x00', '\x6b',
|
||||
'\xb9', '\xbd', '\x38', '\x0a', '\x11'}
|
||||
row := db.QueryRow("SELECT $1::uuid", b)
|
||||
|
||||
var result string
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != string("a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11") {
|
||||
t.Fatalf("expected %v but got %v", b, result)
|
||||
}
|
||||
} else {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22021" {
|
||||
t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringToUUID(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
s := "a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11"
|
||||
row := db.QueryRow("SELECT $1::uuid", s)
|
||||
|
||||
var result string
|
||||
err := row.Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if result != s {
|
||||
t.Fatalf("expected %v but got %v", s, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextByteSliceToInt(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
expected := 12345678
|
||||
b := []byte(fmt.Sprintf("%d", expected))
|
||||
row := db.QueryRow("SELECT $1::int", b)
|
||||
|
||||
var result int
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22P03" {
|
||||
t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != expected {
|
||||
t.Fatalf("expected %v but got %v", expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBinaryByteSliceToInt(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
expected := 12345678
|
||||
b := []byte{'\x00', '\xbc', '\x61', '\x4e'}
|
||||
row := db.QueryRow("SELECT $1::int", b)
|
||||
|
||||
var result int
|
||||
err := row.Scan(&result)
|
||||
if forceBinaryParameters() {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != expected {
|
||||
t.Fatalf("expected %v but got %v", expected, result)
|
||||
}
|
||||
} else {
|
||||
pqErr := err.(*Error)
|
||||
if pqErr == nil {
|
||||
t.Errorf("Expected to get error")
|
||||
} else if pqErr.Code != "22021" {
|
||||
t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextDecodeIntoString(t *testing.T) {
|
||||
input := []byte("hello world")
|
||||
want := string(input)
|
||||
for _, typ := range []oid.Oid{oid.T_char, oid.T_varchar, oid.T_text} {
|
||||
got := decode(¶meterStatus{}, input, typ, formatText)
|
||||
if got != want {
|
||||
t.Errorf("invalid string decoding output for %T(%+v), got %v but expected %v", typ, typ, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteaOutputFormatEncoding(t *testing.T) {
|
||||
input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123")
|
||||
want := []byte("\\x5c78000102fffe6162636465666730313233")
|
||||
got := encode(¶meterStatus{serverVersion: 90000}, input, oid.T_bytea)
|
||||
if !bytes.Equal(want, got) {
|
||||
t.Errorf("invalid hex bytea output, got %v but expected %v", got, want)
|
||||
}
|
||||
|
||||
want = []byte("\\\\x\\000\\001\\002\\377\\376abcdefg0123")
|
||||
got = encode(¶meterStatus{serverVersion: 84000}, input, oid.T_bytea)
|
||||
if !bytes.Equal(want, got) {
|
||||
t.Errorf("invalid escape bytea output, got %v but expected %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteaOutputFormats(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
if getServerVersion(t, db) < 90000 {
|
||||
// skip
|
||||
return
|
||||
}
|
||||
|
||||
testByteaOutputFormat := func(f string, usePrepared bool) {
|
||||
expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08")
|
||||
sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')"
|
||||
|
||||
var data []byte
|
||||
|
||||
// use a txn to avoid relying on getting the same connection
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
|
||||
_, err = txn.Exec("SET LOCAL bytea_output TO " + f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var rows *sql.Rows
|
||||
var stmt *sql.Stmt
|
||||
if usePrepared {
|
||||
stmt, err = txn.Prepare(sqlQuery)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rows, err = stmt.Query()
|
||||
} else {
|
||||
// use Query; QueryRow would hide the actual error
|
||||
rows, err = txn.Query(sqlQuery)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !rows.Next() {
|
||||
if rows.Err() != nil {
|
||||
t.Fatal(rows.Err())
|
||||
}
|
||||
t.Fatal("shouldn't happen")
|
||||
}
|
||||
err = rows.Scan(&data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt != nil {
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(data, expectedData) {
|
||||
t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData)
|
||||
}
|
||||
}
|
||||
|
||||
testByteaOutputFormat("hex", false)
|
||||
testByteaOutputFormat("escape", false)
|
||||
testByteaOutputFormat("hex", true)
|
||||
testByteaOutputFormat("escape", true)
|
||||
}
|
||||
|
||||
func TestAppendEncodedText(t *testing.T) {
|
||||
var buf []byte
|
||||
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, int64(10))
|
||||
buf = append(buf, '\t')
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, 42.0000000001)
|
||||
buf = append(buf, '\t')
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, "hello\tworld")
|
||||
buf = append(buf, '\t')
|
||||
buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, []byte{0, 128, 255})
|
||||
|
||||
if string(buf) != "10\t42.0000000001\thello\\tworld\t\\\\x0080ff" {
|
||||
t.Fatal(string(buf))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendEscapedText(t *testing.T) {
|
||||
if esc := appendEscapedText(nil, "hallo\tescape"); string(esc) != "hallo\\tescape" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
if esc := appendEscapedText(nil, "hallo\\tescape\n"); string(esc) != "hallo\\\\tescape\\n" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
if esc := appendEscapedText(nil, "\n\r\t\f"); string(esc) != "\\n\\r\\t\f" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendEscapedTextExistingBuffer(t *testing.T) {
|
||||
var buf []byte
|
||||
buf = []byte("123\t")
|
||||
if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
buf = []byte("123\t")
|
||||
if esc := appendEscapedText(buf, "hallo\\tescape\n"); string(esc) != "123\thallo\\\\tescape\\n" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
buf = []byte("123\t")
|
||||
if esc := appendEscapedText(buf, "\n\r\t\f"); string(esc) != "123\t\\n\\r\\t\f" {
|
||||
t.Fatal(string(esc))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendEscapedText(b *testing.B) {
|
||||
longString := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
longString += "123456789\n"
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
appendEscapedText(nil, longString)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendEscapedTextNoEscape(b *testing.B) {
|
||||
longString := ""
|
||||
for i := 0; i < 100; i++ {
|
||||
longString += "1234567890"
|
||||
}
|
||||
for i := 0; i < b.N; i++ {
|
||||
appendEscapedText(nil, longString)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,508 @@
|
|||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// Error severities
|
||||
const (
|
||||
Efatal = "FATAL"
|
||||
Epanic = "PANIC"
|
||||
Ewarning = "WARNING"
|
||||
Enotice = "NOTICE"
|
||||
Edebug = "DEBUG"
|
||||
Einfo = "INFO"
|
||||
Elog = "LOG"
|
||||
)
|
||||
|
||||
// Error represents an error communicating with the server.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields
|
||||
type Error struct {
|
||||
Severity string
|
||||
Code ErrorCode
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position string
|
||||
InternalPosition string
|
||||
InternalQuery string
|
||||
Where string
|
||||
Schema string
|
||||
Table string
|
||||
Column string
|
||||
DataTypeName string
|
||||
Constraint string
|
||||
File string
|
||||
Line string
|
||||
Routine string
|
||||
}
|
||||
|
||||
// ErrorCode is a five-character error code.
|
||||
type ErrorCode string
|
||||
|
||||
// Name returns a more human friendly rendering of the error code, namely the
|
||||
// "condition name".
|
||||
//
|
||||
// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for
|
||||
// details.
|
||||
func (ec ErrorCode) Name() string {
|
||||
return errorCodeNames[ec]
|
||||
}
|
||||
|
||||
// ErrorClass is only the class part of an error code.
|
||||
type ErrorClass string
|
||||
|
||||
// Name returns the condition name of an error class. It is equivalent to the
|
||||
// condition name of the "standard" error code (i.e. the one having the last
|
||||
// three characters "000").
|
||||
func (ec ErrorClass) Name() string {
|
||||
return errorCodeNames[ErrorCode(ec+"000")]
|
||||
}
|
||||
|
||||
// Class returns the error class, e.g. "28".
|
||||
//
|
||||
// See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for
|
||||
// details.
|
||||
func (ec ErrorCode) Class() ErrorClass {
|
||||
return ErrorClass(ec[0:2])
|
||||
}
|
||||
|
||||
// errorCodeNames is a mapping between the five-character error codes and the
|
||||
// human readable "condition names". It is derived from the list at
|
||||
// http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html
|
||||
var errorCodeNames = map[ErrorCode]string{
|
||||
// Class 00 - Successful Completion
|
||||
"00000": "successful_completion",
|
||||
// Class 01 - Warning
|
||||
"01000": "warning",
|
||||
"0100C": "dynamic_result_sets_returned",
|
||||
"01008": "implicit_zero_bit_padding",
|
||||
"01003": "null_value_eliminated_in_set_function",
|
||||
"01007": "privilege_not_granted",
|
||||
"01006": "privilege_not_revoked",
|
||||
"01004": "string_data_right_truncation",
|
||||
"01P01": "deprecated_feature",
|
||||
// Class 02 - No Data (this is also a warning class per the SQL standard)
|
||||
"02000": "no_data",
|
||||
"02001": "no_additional_dynamic_result_sets_returned",
|
||||
// Class 03 - SQL Statement Not Yet Complete
|
||||
"03000": "sql_statement_not_yet_complete",
|
||||
// Class 08 - Connection Exception
|
||||
"08000": "connection_exception",
|
||||
"08003": "connection_does_not_exist",
|
||||
"08006": "connection_failure",
|
||||
"08001": "sqlclient_unable_to_establish_sqlconnection",
|
||||
"08004": "sqlserver_rejected_establishment_of_sqlconnection",
|
||||
"08007": "transaction_resolution_unknown",
|
||||
"08P01": "protocol_violation",
|
||||
// Class 09 - Triggered Action Exception
|
||||
"09000": "triggered_action_exception",
|
||||
// Class 0A - Feature Not Supported
|
||||
"0A000": "feature_not_supported",
|
||||
// Class 0B - Invalid Transaction Initiation
|
||||
"0B000": "invalid_transaction_initiation",
|
||||
// Class 0F - Locator Exception
|
||||
"0F000": "locator_exception",
|
||||
"0F001": "invalid_locator_specification",
|
||||
// Class 0L - Invalid Grantor
|
||||
"0L000": "invalid_grantor",
|
||||
"0LP01": "invalid_grant_operation",
|
||||
// Class 0P - Invalid Role Specification
|
||||
"0P000": "invalid_role_specification",
|
||||
// Class 0Z - Diagnostics Exception
|
||||
"0Z000": "diagnostics_exception",
|
||||
"0Z002": "stacked_diagnostics_accessed_without_active_handler",
|
||||
// Class 20 - Case Not Found
|
||||
"20000": "case_not_found",
|
||||
// Class 21 - Cardinality Violation
|
||||
"21000": "cardinality_violation",
|
||||
// Class 22 - Data Exception
|
||||
"22000": "data_exception",
|
||||
"2202E": "array_subscript_error",
|
||||
"22021": "character_not_in_repertoire",
|
||||
"22008": "datetime_field_overflow",
|
||||
"22012": "division_by_zero",
|
||||
"22005": "error_in_assignment",
|
||||
"2200B": "escape_character_conflict",
|
||||
"22022": "indicator_overflow",
|
||||
"22015": "interval_field_overflow",
|
||||
"2201E": "invalid_argument_for_logarithm",
|
||||
"22014": "invalid_argument_for_ntile_function",
|
||||
"22016": "invalid_argument_for_nth_value_function",
|
||||
"2201F": "invalid_argument_for_power_function",
|
||||
"2201G": "invalid_argument_for_width_bucket_function",
|
||||
"22018": "invalid_character_value_for_cast",
|
||||
"22007": "invalid_datetime_format",
|
||||
"22019": "invalid_escape_character",
|
||||
"2200D": "invalid_escape_octet",
|
||||
"22025": "invalid_escape_sequence",
|
||||
"22P06": "nonstandard_use_of_escape_character",
|
||||
"22010": "invalid_indicator_parameter_value",
|
||||
"22023": "invalid_parameter_value",
|
||||
"2201B": "invalid_regular_expression",
|
||||
"2201W": "invalid_row_count_in_limit_clause",
|
||||
"2201X": "invalid_row_count_in_result_offset_clause",
|
||||
"22009": "invalid_time_zone_displacement_value",
|
||||
"2200C": "invalid_use_of_escape_character",
|
||||
"2200G": "most_specific_type_mismatch",
|
||||
"22004": "null_value_not_allowed",
|
||||
"22002": "null_value_no_indicator_parameter",
|
||||
"22003": "numeric_value_out_of_range",
|
||||
"22026": "string_data_length_mismatch",
|
||||
"22001": "string_data_right_truncation",
|
||||
"22011": "substring_error",
|
||||
"22027": "trim_error",
|
||||
"22024": "unterminated_c_string",
|
||||
"2200F": "zero_length_character_string",
|
||||
"22P01": "floating_point_exception",
|
||||
"22P02": "invalid_text_representation",
|
||||
"22P03": "invalid_binary_representation",
|
||||
"22P04": "bad_copy_file_format",
|
||||
"22P05": "untranslatable_character",
|
||||
"2200L": "not_an_xml_document",
|
||||
"2200M": "invalid_xml_document",
|
||||
"2200N": "invalid_xml_content",
|
||||
"2200S": "invalid_xml_comment",
|
||||
"2200T": "invalid_xml_processing_instruction",
|
||||
// Class 23 - Integrity Constraint Violation
|
||||
"23000": "integrity_constraint_violation",
|
||||
"23001": "restrict_violation",
|
||||
"23502": "not_null_violation",
|
||||
"23503": "foreign_key_violation",
|
||||
"23505": "unique_violation",
|
||||
"23514": "check_violation",
|
||||
"23P01": "exclusion_violation",
|
||||
// Class 24 - Invalid Cursor State
|
||||
"24000": "invalid_cursor_state",
|
||||
// Class 25 - Invalid Transaction State
|
||||
"25000": "invalid_transaction_state",
|
||||
"25001": "active_sql_transaction",
|
||||
"25002": "branch_transaction_already_active",
|
||||
"25008": "held_cursor_requires_same_isolation_level",
|
||||
"25003": "inappropriate_access_mode_for_branch_transaction",
|
||||
"25004": "inappropriate_isolation_level_for_branch_transaction",
|
||||
"25005": "no_active_sql_transaction_for_branch_transaction",
|
||||
"25006": "read_only_sql_transaction",
|
||||
"25007": "schema_and_data_statement_mixing_not_supported",
|
||||
"25P01": "no_active_sql_transaction",
|
||||
"25P02": "in_failed_sql_transaction",
|
||||
// Class 26 - Invalid SQL Statement Name
|
||||
"26000": "invalid_sql_statement_name",
|
||||
// Class 27 - Triggered Data Change Violation
|
||||
"27000": "triggered_data_change_violation",
|
||||
// Class 28 - Invalid Authorization Specification
|
||||
"28000": "invalid_authorization_specification",
|
||||
"28P01": "invalid_password",
|
||||
// Class 2B - Dependent Privilege Descriptors Still Exist
|
||||
"2B000": "dependent_privilege_descriptors_still_exist",
|
||||
"2BP01": "dependent_objects_still_exist",
|
||||
// Class 2D - Invalid Transaction Termination
|
||||
"2D000": "invalid_transaction_termination",
|
||||
// Class 2F - SQL Routine Exception
|
||||
"2F000": "sql_routine_exception",
|
||||
"2F005": "function_executed_no_return_statement",
|
||||
"2F002": "modifying_sql_data_not_permitted",
|
||||
"2F003": "prohibited_sql_statement_attempted",
|
||||
"2F004": "reading_sql_data_not_permitted",
|
||||
// Class 34 - Invalid Cursor Name
|
||||
"34000": "invalid_cursor_name",
|
||||
// Class 38 - External Routine Exception
|
||||
"38000": "external_routine_exception",
|
||||
"38001": "containing_sql_not_permitted",
|
||||
"38002": "modifying_sql_data_not_permitted",
|
||||
"38003": "prohibited_sql_statement_attempted",
|
||||
"38004": "reading_sql_data_not_permitted",
|
||||
// Class 39 - External Routine Invocation Exception
|
||||
"39000": "external_routine_invocation_exception",
|
||||
"39001": "invalid_sqlstate_returned",
|
||||
"39004": "null_value_not_allowed",
|
||||
"39P01": "trigger_protocol_violated",
|
||||
"39P02": "srf_protocol_violated",
|
||||
// Class 3B - Savepoint Exception
|
||||
"3B000": "savepoint_exception",
|
||||
"3B001": "invalid_savepoint_specification",
|
||||
// Class 3D - Invalid Catalog Name
|
||||
"3D000": "invalid_catalog_name",
|
||||
// Class 3F - Invalid Schema Name
|
||||
"3F000": "invalid_schema_name",
|
||||
// Class 40 - Transaction Rollback
|
||||
"40000": "transaction_rollback",
|
||||
"40002": "transaction_integrity_constraint_violation",
|
||||
"40001": "serialization_failure",
|
||||
"40003": "statement_completion_unknown",
|
||||
"40P01": "deadlock_detected",
|
||||
// Class 42 - Syntax Error or Access Rule Violation
|
||||
"42000": "syntax_error_or_access_rule_violation",
|
||||
"42601": "syntax_error",
|
||||
"42501": "insufficient_privilege",
|
||||
"42846": "cannot_coerce",
|
||||
"42803": "grouping_error",
|
||||
"42P20": "windowing_error",
|
||||
"42P19": "invalid_recursion",
|
||||
"42830": "invalid_foreign_key",
|
||||
"42602": "invalid_name",
|
||||
"42622": "name_too_long",
|
||||
"42939": "reserved_name",
|
||||
"42804": "datatype_mismatch",
|
||||
"42P18": "indeterminate_datatype",
|
||||
"42P21": "collation_mismatch",
|
||||
"42P22": "indeterminate_collation",
|
||||
"42809": "wrong_object_type",
|
||||
"42703": "undefined_column",
|
||||
"42883": "undefined_function",
|
||||
"42P01": "undefined_table",
|
||||
"42P02": "undefined_parameter",
|
||||
"42704": "undefined_object",
|
||||
"42701": "duplicate_column",
|
||||
"42P03": "duplicate_cursor",
|
||||
"42P04": "duplicate_database",
|
||||
"42723": "duplicate_function",
|
||||
"42P05": "duplicate_prepared_statement",
|
||||
"42P06": "duplicate_schema",
|
||||
"42P07": "duplicate_table",
|
||||
"42712": "duplicate_alias",
|
||||
"42710": "duplicate_object",
|
||||
"42702": "ambiguous_column",
|
||||
"42725": "ambiguous_function",
|
||||
"42P08": "ambiguous_parameter",
|
||||
"42P09": "ambiguous_alias",
|
||||
"42P10": "invalid_column_reference",
|
||||
"42611": "invalid_column_definition",
|
||||
"42P11": "invalid_cursor_definition",
|
||||
"42P12": "invalid_database_definition",
|
||||
"42P13": "invalid_function_definition",
|
||||
"42P14": "invalid_prepared_statement_definition",
|
||||
"42P15": "invalid_schema_definition",
|
||||
"42P16": "invalid_table_definition",
|
||||
"42P17": "invalid_object_definition",
|
||||
// Class 44 - WITH CHECK OPTION Violation
|
||||
"44000": "with_check_option_violation",
|
||||
// Class 53 - Insufficient Resources
|
||||
"53000": "insufficient_resources",
|
||||
"53100": "disk_full",
|
||||
"53200": "out_of_memory",
|
||||
"53300": "too_many_connections",
|
||||
"53400": "configuration_limit_exceeded",
|
||||
// Class 54 - Program Limit Exceeded
|
||||
"54000": "program_limit_exceeded",
|
||||
"54001": "statement_too_complex",
|
||||
"54011": "too_many_columns",
|
||||
"54023": "too_many_arguments",
|
||||
// Class 55 - Object Not In Prerequisite State
|
||||
"55000": "object_not_in_prerequisite_state",
|
||||
"55006": "object_in_use",
|
||||
"55P02": "cant_change_runtime_param",
|
||||
"55P03": "lock_not_available",
|
||||
// Class 57 - Operator Intervention
|
||||
"57000": "operator_intervention",
|
||||
"57014": "query_canceled",
|
||||
"57P01": "admin_shutdown",
|
||||
"57P02": "crash_shutdown",
|
||||
"57P03": "cannot_connect_now",
|
||||
"57P04": "database_dropped",
|
||||
// Class 58 - System Error (errors external to PostgreSQL itself)
|
||||
"58000": "system_error",
|
||||
"58030": "io_error",
|
||||
"58P01": "undefined_file",
|
||||
"58P02": "duplicate_file",
|
||||
// Class F0 - Configuration File Error
|
||||
"F0000": "config_file_error",
|
||||
"F0001": "lock_file_exists",
|
||||
// Class HV - Foreign Data Wrapper Error (SQL/MED)
|
||||
"HV000": "fdw_error",
|
||||
"HV005": "fdw_column_name_not_found",
|
||||
"HV002": "fdw_dynamic_parameter_value_needed",
|
||||
"HV010": "fdw_function_sequence_error",
|
||||
"HV021": "fdw_inconsistent_descriptor_information",
|
||||
"HV024": "fdw_invalid_attribute_value",
|
||||
"HV007": "fdw_invalid_column_name",
|
||||
"HV008": "fdw_invalid_column_number",
|
||||
"HV004": "fdw_invalid_data_type",
|
||||
"HV006": "fdw_invalid_data_type_descriptors",
|
||||
"HV091": "fdw_invalid_descriptor_field_identifier",
|
||||
"HV00B": "fdw_invalid_handle",
|
||||
"HV00C": "fdw_invalid_option_index",
|
||||
"HV00D": "fdw_invalid_option_name",
|
||||
"HV090": "fdw_invalid_string_length_or_buffer_length",
|
||||
"HV00A": "fdw_invalid_string_format",
|
||||
"HV009": "fdw_invalid_use_of_null_pointer",
|
||||
"HV014": "fdw_too_many_handles",
|
||||
"HV001": "fdw_out_of_memory",
|
||||
"HV00P": "fdw_no_schemas",
|
||||
"HV00J": "fdw_option_name_not_found",
|
||||
"HV00K": "fdw_reply_handle",
|
||||
"HV00Q": "fdw_schema_not_found",
|
||||
"HV00R": "fdw_table_not_found",
|
||||
"HV00L": "fdw_unable_to_create_execution",
|
||||
"HV00M": "fdw_unable_to_create_reply",
|
||||
"HV00N": "fdw_unable_to_establish_connection",
|
||||
// Class P0 - PL/pgSQL Error
|
||||
"P0000": "plpgsql_error",
|
||||
"P0001": "raise_exception",
|
||||
"P0002": "no_data_found",
|
||||
"P0003": "too_many_rows",
|
||||
// Class XX - Internal Error
|
||||
"XX000": "internal_error",
|
||||
"XX001": "data_corrupted",
|
||||
"XX002": "index_corrupted",
|
||||
}
|
||||
|
||||
func parseError(r *readBuf) *Error {
|
||||
err := new(Error)
|
||||
for t := r.byte(); t != 0; t = r.byte() {
|
||||
msg := r.string()
|
||||
switch t {
|
||||
case 'S':
|
||||
err.Severity = msg
|
||||
case 'C':
|
||||
err.Code = ErrorCode(msg)
|
||||
case 'M':
|
||||
err.Message = msg
|
||||
case 'D':
|
||||
err.Detail = msg
|
||||
case 'H':
|
||||
err.Hint = msg
|
||||
case 'P':
|
||||
err.Position = msg
|
||||
case 'p':
|
||||
err.InternalPosition = msg
|
||||
case 'q':
|
||||
err.InternalQuery = msg
|
||||
case 'W':
|
||||
err.Where = msg
|
||||
case 's':
|
||||
err.Schema = msg
|
||||
case 't':
|
||||
err.Table = msg
|
||||
case 'c':
|
||||
err.Column = msg
|
||||
case 'd':
|
||||
err.DataTypeName = msg
|
||||
case 'n':
|
||||
err.Constraint = msg
|
||||
case 'F':
|
||||
err.File = msg
|
||||
case 'L':
|
||||
err.Line = msg
|
||||
case 'R':
|
||||
err.Routine = msg
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Fatal returns true if the Error Severity is fatal.
|
||||
func (err *Error) Fatal() bool {
|
||||
return err.Severity == Efatal
|
||||
}
|
||||
|
||||
// Get implements the legacy PGError interface. New code should use the fields
|
||||
// of the Error struct directly.
|
||||
func (err *Error) Get(k byte) (v string) {
|
||||
switch k {
|
||||
case 'S':
|
||||
return err.Severity
|
||||
case 'C':
|
||||
return string(err.Code)
|
||||
case 'M':
|
||||
return err.Message
|
||||
case 'D':
|
||||
return err.Detail
|
||||
case 'H':
|
||||
return err.Hint
|
||||
case 'P':
|
||||
return err.Position
|
||||
case 'p':
|
||||
return err.InternalPosition
|
||||
case 'q':
|
||||
return err.InternalQuery
|
||||
case 'W':
|
||||
return err.Where
|
||||
case 's':
|
||||
return err.Schema
|
||||
case 't':
|
||||
return err.Table
|
||||
case 'c':
|
||||
return err.Column
|
||||
case 'd':
|
||||
return err.DataTypeName
|
||||
case 'n':
|
||||
return err.Constraint
|
||||
case 'F':
|
||||
return err.File
|
||||
case 'L':
|
||||
return err.Line
|
||||
case 'R':
|
||||
return err.Routine
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (err Error) Error() string {
|
||||
return "pq: " + err.Message
|
||||
}
|
||||
|
||||
// PGError is an interface used by previous versions of pq. It is provided
|
||||
// only to support legacy code. New code should use the Error type.
|
||||
type PGError interface {
|
||||
Error() string
|
||||
Fatal() bool
|
||||
Get(k byte) (v string)
|
||||
}
|
||||
|
||||
func errorf(s string, args ...interface{}) {
|
||||
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
|
||||
}
|
||||
|
||||
func errRecoverNoErrBadConn(err *error) {
|
||||
e := recover()
|
||||
if e == nil {
|
||||
// Do nothing
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
*err, ok = e.(error)
|
||||
if !ok {
|
||||
*err = fmt.Errorf("pq: unexpected error: %#v", e)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) errRecover(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case runtime.Error:
|
||||
c.bad = true
|
||||
panic(v)
|
||||
case *Error:
|
||||
if v.Fatal() {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
case *net.OpError:
|
||||
*err = driver.ErrBadConn
|
||||
case error:
|
||||
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
|
||||
default:
|
||||
c.bad = true
|
||||
panic(fmt.Sprintf("unknown error: %#v", e))
|
||||
}
|
||||
|
||||
// Any time we return ErrBadConn, we need to remember it since *Tx doesn't
|
||||
// mark the connection bad in database/sql.
|
||||
if *err == driver.ErrBadConn {
|
||||
c.bad = true
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
// +build go1.8
|
||||
|
||||
package pq
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMultipleSimpleQuery(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
rows, err := db.Query("select 1; set time zone default; select 2; select 3")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var i int
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 1 {
|
||||
t.Fatalf("expected 1, got %d", i)
|
||||
}
|
||||
}
|
||||
if !rows.NextResultSet() {
|
||||
t.Fatal("expected more result sets", rows.Err())
|
||||
}
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 2 {
|
||||
t.Fatalf("expected 2, got %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure that if we ignore a result we can still query.
|
||||
|
||||
rows, err = db.Query("select 4; select 5")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 4 {
|
||||
t.Fatalf("expected 4, got %d", i)
|
||||
}
|
||||
}
|
||||
if !rows.NextResultSet() {
|
||||
t.Fatal("expected more result sets", rows.Err())
|
||||
}
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&i); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if i != 5 {
|
||||
t.Fatalf("expected 5, got %d", i)
|
||||
}
|
||||
}
|
||||
if rows.NextResultSet() {
|
||||
t.Fatal("unexpected result set")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
package pq
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIssue494(t *testing.T) {
|
||||
db := openTestConn(t)
|
||||
defer db.Close()
|
||||
|
||||
query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)`
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
txn, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := txn.Prepare(CopyIn("t", "i")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := txn.Query("SELECT 1"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,782 @@
|
|||
package pq
|
||||
|
||||
// Package pq is a pure Go Postgres driver for the database/sql package.
|
||||
// This module contains support for Postgres LISTEN/NOTIFY.
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Notification represents a single notification from the database.
|
||||
type Notification struct {
|
||||
// Process ID (PID) of the notifying postgres backend.
|
||||
BePid int
|
||||
// Name of the channel the notification was sent on.
|
||||
Channel string
|
||||
// Payload, or the empty string if unspecified.
|
||||
Extra string
|
||||
}
|
||||
|
||||
func recvNotification(r *readBuf) *Notification {
|
||||
bePid := r.int32()
|
||||
channel := r.string()
|
||||
extra := r.string()
|
||||
|
||||
return &Notification{bePid, channel, extra}
|
||||
}
|
||||
|
||||
const (
|
||||
connStateIdle int32 = iota
|
||||
connStateExpectResponse
|
||||
connStateExpectReadyForQuery
|
||||
)
|
||||
|
||||
type message struct {
|
||||
typ byte
|
||||
err error
|
||||
}
|
||||
|
||||
var errListenerConnClosed = errors.New("pq: ListenerConn has been closed")
|
||||
|
||||
// ListenerConn is a low-level interface for waiting for notifications. You
|
||||
// should use Listener instead.
|
||||
type ListenerConn struct {
|
||||
// guards cn and err
|
||||
connectionLock sync.Mutex
|
||||
cn *conn
|
||||
err error
|
||||
|
||||
connState int32
|
||||
|
||||
// the sending goroutine will be holding this lock
|
||||
senderLock sync.Mutex
|
||||
|
||||
notificationChan chan<- *Notification
|
||||
|
||||
replyChan chan message
|
||||
}
|
||||
|
||||
// Creates a new ListenerConn. Use NewListener instead.
|
||||
func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) {
|
||||
return newDialListenerConn(defaultDialer{}, name, notificationChan)
|
||||
}
|
||||
|
||||
func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) {
|
||||
cn, err := DialOpen(d, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := &ListenerConn{
|
||||
cn: cn.(*conn),
|
||||
notificationChan: c,
|
||||
connState: connStateIdle,
|
||||
replyChan: make(chan message, 2),
|
||||
}
|
||||
|
||||
go l.listenerConnMain()
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// We can only allow one goroutine at a time to be running a query on the
|
||||
// connection for various reasons, so the goroutine sending on the connection
|
||||
// must be holding senderLock.
|
||||
//
|
||||
// Returns an error if an unrecoverable error has occurred and the ListenerConn
|
||||
// should be abandoned.
|
||||
func (l *ListenerConn) acquireSenderLock() error {
|
||||
// we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery
|
||||
l.senderLock.Lock()
|
||||
|
||||
l.connectionLock.Lock()
|
||||
err := l.err
|
||||
l.connectionLock.Unlock()
|
||||
if err != nil {
|
||||
l.senderLock.Unlock()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *ListenerConn) releaseSenderLock() {
|
||||
l.senderLock.Unlock()
|
||||
}
|
||||
|
||||
// setState advances the protocol state to newState. Returns false if moving
|
||||
// to that state from the current state is not allowed.
|
||||
func (l *ListenerConn) setState(newState int32) bool {
|
||||
var expectedState int32
|
||||
|
||||
switch newState {
|
||||
case connStateIdle:
|
||||
expectedState = connStateExpectReadyForQuery
|
||||
case connStateExpectResponse:
|
||||
expectedState = connStateIdle
|
||||
case connStateExpectReadyForQuery:
|
||||
expectedState = connStateExpectResponse
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected listenerConnState %d", newState))
|
||||
}
|
||||
|
||||
return atomic.CompareAndSwapInt32(&l.connState, expectedState, newState)
|
||||
}
|
||||
|
||||
// Main logic is here: receive messages from the postgres backend, forward
|
||||
// notifications and query replies and keep the internal state in sync with the
|
||||
// protocol state. Returns when the connection has been lost, is about to go
|
||||
// away or should be discarded because we couldn't agree on the state with the
|
||||
// server backend.
|
||||
func (l *ListenerConn) listenerConnLoop() (err error) {
|
||||
defer errRecoverNoErrBadConn(&err)
|
||||
|
||||
r := &readBuf{}
|
||||
for {
|
||||
t, err := l.cn.recvMessage(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case 'A':
|
||||
// recvNotification copies all the data so we don't need to worry
|
||||
// about the scratch buffer being overwritten.
|
||||
l.notificationChan <- recvNotification(r)
|
||||
|
||||
case 'T', 'D':
|
||||
// only used by tests; ignore
|
||||
|
||||
case 'E':
|
||||
// We might receive an ErrorResponse even when not in a query; it
|
||||
// is expected that the server will close the connection after
|
||||
// that, but we should make sure that the error we display is the
|
||||
// one from the stray ErrorResponse, not io.ErrUnexpectedEOF.
|
||||
if !l.setState(connStateExpectReadyForQuery) {
|
||||
return parseError(r)
|
||||
}
|
||||
l.replyChan <- message{t, parseError(r)}
|
||||
|
||||
case 'C', 'I':
|
||||
if !l.setState(connStateExpectReadyForQuery) {
|
||||
// protocol out of sync
|
||||
return fmt.Errorf("unexpected CommandComplete")
|
||||
}
|
||||
// ExecSimpleQuery doesn't need to know about this message
|
||||
|
||||
case 'Z':
|
||||
if !l.setState(connStateIdle) {
|
||||
// protocol out of sync
|
||||
return fmt.Errorf("unexpected ReadyForQuery")
|
||||
}
|
||||
l.replyChan <- message{t, nil}
|
||||
|
||||
case 'N', 'S':
|
||||
// ignore
|
||||
default:
|
||||
return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is the main routine for the goroutine receiving on the database
|
||||
// connection. Most of the main logic is in listenerConnLoop.
|
||||
func (l *ListenerConn) listenerConnMain() {
|
||||
err := l.listenerConnLoop()
|
||||
|
||||
// listenerConnLoop terminated; we're done, but we still have to clean up.
|
||||
// Make sure nobody tries to start any new queries by making sure the err
|
||||
// pointer is set. It is important that we do not overwrite its value; a
|
||||
// connection could be closed by either this goroutine or one sending on
|
||||
// the connection -- whoever closes the connection is assumed to have the
|
||||
// more meaningful error message (as the other one will probably get
|
||||
// net.errClosed), so that goroutine sets the error we expose while the
|
||||
// other error is discarded. If the connection is lost while two
|
||||
// goroutines are operating on the socket, it probably doesn't matter which
|
||||
// error we expose so we don't try to do anything more complex.
|
||||
l.connectionLock.Lock()
|
||||
if l.err == nil {
|
||||
l.err = err
|
||||
}
|
||||
l.cn.Close()
|
||||
l.connectionLock.Unlock()
|
||||
|
||||
// There might be a query in-flight; make sure nobody's waiting for a
|
||||
// response to it, since there's not going to be one.
|
||||
close(l.replyChan)
|
||||
|
||||
// let the listener know we're done
|
||||
close(l.notificationChan)
|
||||
|
||||
// this ListenerConn is done
|
||||
}
|
||||
|
||||
// Send a LISTEN query to the server. See ExecSimpleQuery.
|
||||
func (l *ListenerConn) Listen(channel string) (bool, error) {
|
||||
return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel))
|
||||
}
|
||||
|
||||
// Send an UNLISTEN query to the server. See ExecSimpleQuery.
|
||||
func (l *ListenerConn) Unlisten(channel string) (bool, error) {
|
||||
return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel))
|
||||
}
|
||||
|
||||
// Send `UNLISTEN *` to the server. See ExecSimpleQuery.
|
||||
func (l *ListenerConn) UnlistenAll() (bool, error) {
|
||||
return l.ExecSimpleQuery("UNLISTEN *")
|
||||
}
|
||||
|
||||
// Ping the remote server to make sure it's alive. Non-nil error means the
|
||||
// connection has failed and should be abandoned.
|
||||
func (l *ListenerConn) Ping() error {
|
||||
sent, err := l.ExecSimpleQuery("")
|
||||
if !sent {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
// shouldn't happen
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Attempt to send a query on the connection. Returns an error if sending the
|
||||
// query failed, and the caller should initiate closure of this connection.
|
||||
// The caller must be holding senderLock (see acquireSenderLock and
|
||||
// releaseSenderLock).
|
||||
func (l *ListenerConn) sendSimpleQuery(q string) (err error) {
|
||||
defer errRecoverNoErrBadConn(&err)
|
||||
|
||||
// must set connection state before sending the query
|
||||
if !l.setState(connStateExpectResponse) {
|
||||
panic("two queries running at the same time")
|
||||
}
|
||||
|
||||
// Can't use l.cn.writeBuf here because it uses the scratch buffer which
|
||||
// might get overwritten by listenerConnLoop.
|
||||
b := &writeBuf{
|
||||
buf: []byte("Q\x00\x00\x00\x00"),
|
||||
pos: 1,
|
||||
}
|
||||
b.string(q)
|
||||
l.cn.send(b)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute a "simple query" (i.e. one with no bindable parameters) on the
|
||||
// connection. The possible return values are:
|
||||
// 1) "executed" is true; the query was executed to completion on the
|
||||
// database server. If the query failed, err will be set to the error
|
||||
// returned by the database, otherwise err will be nil.
|
||||
// 2) If "executed" is false, the query could not be executed on the remote
|
||||
// server. err will be non-nil.
|
||||
//
|
||||
// After a call to ExecSimpleQuery has returned an executed=false value, the
|
||||
// connection has either been closed or will be closed shortly thereafter, and
|
||||
// all subsequently executed queries will return an error.
|
||||
func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) {
|
||||
if err = l.acquireSenderLock(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer l.releaseSenderLock()
|
||||
|
||||
err = l.sendSimpleQuery(q)
|
||||
if err != nil {
|
||||
// We can't know what state the protocol is in, so we need to abandon
|
||||
// this connection.
|
||||
l.connectionLock.Lock()
|
||||
// Set the error pointer if it hasn't been set already; see
|
||||
// listenerConnMain.
|
||||
if l.err == nil {
|
||||
l.err = err
|
||||
}
|
||||
l.connectionLock.Unlock()
|
||||
l.cn.c.Close()
|
||||
return false, err
|
||||
}
|
||||
|
||||
// now we just wait for a reply..
|
||||
for {
|
||||
m, ok := <-l.replyChan
|
||||
if !ok {
|
||||
// We lost the connection to server, don't bother waiting for a
|
||||
// a response. err should have been set already.
|
||||
l.connectionLock.Lock()
|
||||
err := l.err
|
||||
l.connectionLock.Unlock()
|
||||
return false, err
|
||||
}
|
||||
switch m.typ {
|
||||
case 'Z':
|
||||
// sanity check
|
||||
if m.err != nil {
|
||||
panic("m.err != nil")
|
||||
}
|
||||
// done; err might or might not be set
|
||||
return true, err
|
||||
|
||||
case 'E':
|
||||
// sanity check
|
||||
if m.err == nil {
|
||||
panic("m.err == nil")
|
||||
}
|
||||
// server responded with an error; ReadyForQuery to follow
|
||||
err = m.err
|
||||
|
||||
default:
|
||||
return false, fmt.Errorf("unknown response for simple query: %q", m.typ)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ListenerConn) Close() error {
|
||||
l.connectionLock.Lock()
|
||||
if l.err != nil {
|
||||
l.connectionLock.Unlock()
|
||||
return errListenerConnClosed
|
||||
}
|
||||
l.err = errListenerConnClosed
|
||||
l.connectionLock.Unlock()
|
||||
// We can't send anything on the connection without holding senderLock.
|
||||
// Simply close the net.Conn to wake up everyone operating on it.
|
||||
return l.cn.c.Close()
|
||||
}
|
||||
|
||||
// Err() returns the reason the connection was closed. It is not safe to call
|
||||
// this function until l.Notify has been closed.
|
||||
func (l *ListenerConn) Err() error {
|
||||
return l.err
|
||||
}
|
||||
|
||||
var errListenerClosed = errors.New("pq: Listener has been closed")
|
||||
|
||||
var ErrChannelAlreadyOpen = errors.New("pq: channel is already open")
|
||||
var ErrChannelNotOpen = errors.New("pq: channel is not open")
|
||||
|
||||
type ListenerEventType int
|
||||
|
||||
const (
|
||||
// Emitted only when the database connection has been initially
|
||||
// initialized. err will always be nil.
|
||||
ListenerEventConnected ListenerEventType = iota
|
||||
|
||||
// Emitted after a database connection has been lost, either because of an
|
||||
// error or because Close has been called. err will be set to the reason
|
||||
// the database connection was lost.
|
||||
ListenerEventDisconnected
|
||||
|
||||
// Emitted after a database connection has been re-established after
|
||||
// connection loss. err will always be nil. After this event has been
|
||||
// emitted, a nil pq.Notification is sent on the Listener.Notify channel.
|
||||
ListenerEventReconnected
|
||||
|
||||
// Emitted after a connection to the database was attempted, but failed.
|
||||
// err will be set to an error describing why the connection attempt did
|
||||
// not succeed.
|
||||
ListenerEventConnectionAttemptFailed
|
||||
)
|
||||
|
||||
type EventCallbackType func(event ListenerEventType, err error)
|
||||
|
||||
// Listener provides an interface for listening to notifications from a
|
||||
// PostgreSQL database. For general usage information, see section
|
||||
// "Notifications".
|
||||
//
|
||||
// Listener can safely be used from concurrently running goroutines.
|
||||
type Listener struct {
|
||||
// Channel for receiving notifications from the database. In some cases a
|
||||
// nil value will be sent. See section "Notifications" above.
|
||||
Notify chan *Notification
|
||||
|
||||
name string
|
||||
minReconnectInterval time.Duration
|
||||
maxReconnectInterval time.Duration
|
||||
dialer Dialer
|
||||
eventCallback EventCallbackType
|
||||
|
||||
lock sync.Mutex
|
||||
isClosed bool
|
||||
reconnectCond *sync.Cond
|
||||
cn *ListenerConn
|
||||
connNotificationChan <-chan *Notification
|
||||
channels map[string]struct{}
|
||||
}
|
||||
|
||||
// NewListener creates a new database connection dedicated to LISTEN / NOTIFY.
|
||||
//
|
||||
// name should be set to a connection string to be used to establish the
|
||||
// database connection (see section "Connection String Parameters" above).
|
||||
//
|
||||
// minReconnectInterval controls the duration to wait before trying to
|
||||
// re-establish the database connection after connection loss. After each
|
||||
// consecutive failure this interval is doubled, until maxReconnectInterval is
|
||||
// reached. Successfully completing the connection establishment procedure
|
||||
// resets the interval back to minReconnectInterval.
|
||||
//
|
||||
// The last parameter eventCallback can be set to a function which will be
|
||||
// called by the Listener when the state of the underlying database connection
|
||||
// changes. This callback will be called by the goroutine which dispatches the
|
||||
// notifications over the Notify channel, so you should try to avoid doing
|
||||
// potentially time-consuming operations from the callback.
|
||||
func NewListener(name string,
|
||||
minReconnectInterval time.Duration,
|
||||
maxReconnectInterval time.Duration,
|
||||
eventCallback EventCallbackType) *Listener {
|
||||
return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback)
|
||||
}
|
||||
|
||||
// NewDialListener is like NewListener but it takes a Dialer.
|
||||
func NewDialListener(d Dialer,
|
||||
name string,
|
||||
minReconnectInterval time.Duration,
|
||||
maxReconnectInterval time.Duration,
|
||||
eventCallback EventCallbackType) *Listener {
|
||||
|
||||
l := &Listener{
|
||||
name: name,
|
||||
minReconnectInterval: minReconnectInterval,
|
||||
maxReconnectInterval: maxReconnectInterval,
|
||||
dialer: d,
|
||||
eventCallback: eventCallback,
|
||||
|
||||
channels: make(map[string]struct{}),
|
||||
|
||||
Notify: make(chan *Notification, 32),
|
||||
}
|
||||
l.reconnectCond = sync.NewCond(&l.lock)
|
||||
|
||||
go l.listenerMain()
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// Returns the notification channel for this listener. This is the same
|
||||
// channel as Notify, and will not be recreated during the life time of the
|
||||
// Listener.
|
||||
func (l *Listener) NotificationChannel() <-chan *Notification {
|
||||
return l.Notify
|
||||
}
|
||||
|
||||
// Listen starts listening for notifications on a channel. Calls to this
|
||||
// function will block until an acknowledgement has been received from the
|
||||
// server. Note that Listener automatically re-establishes the connection
|
||||
// after connection loss, so this function may block indefinitely if the
|
||||
// connection can not be re-established.
|
||||
//
|
||||
// Listen will only fail in three conditions:
|
||||
// 1) The channel is already open. The returned error will be
|
||||
// ErrChannelAlreadyOpen.
|
||||
// 2) The query was executed on the remote server, but PostgreSQL returned an
|
||||
// error message in response to the query. The returned error will be a
|
||||
// pq.Error containing the information the server supplied.
|
||||
// 3) Close is called on the Listener before the request could be completed.
|
||||
//
|
||||
// The channel name is case-sensitive.
|
||||
func (l *Listener) Listen(channel string) error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.isClosed {
|
||||
return errListenerClosed
|
||||
}
|
||||
|
||||
// The server allows you to issue a LISTEN on a channel which is already
|
||||
// open, but it seems useful to be able to detect this case to spot for
|
||||
// mistakes in application logic. If the application genuinely does't
|
||||
// care, it can check the exported error and ignore it.
|
||||
_, exists := l.channels[channel]
|
||||
if exists {
|
||||
return ErrChannelAlreadyOpen
|
||||
}
|
||||
|
||||
if l.cn != nil {
|
||||
// If gotResponse is true but error is set, the query was executed on
|
||||
// the remote server, but resulted in an error. This should be
|
||||
// relatively rare, so it's fine if we just pass the error to our
|
||||
// caller. However, if gotResponse is false, we could not complete the
|
||||
// query on the remote server and our underlying connection is about
|
||||
// to go away, so we only add relname to l.channels, and wait for
|
||||
// resync() to take care of the rest.
|
||||
gotResponse, err := l.cn.Listen(channel)
|
||||
if gotResponse && err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
l.channels[channel] = struct{}{}
|
||||
for l.cn == nil {
|
||||
l.reconnectCond.Wait()
|
||||
// we let go of the mutex for a while
|
||||
if l.isClosed {
|
||||
return errListenerClosed
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unlisten removes a channel from the Listener's channel list. Returns
|
||||
// ErrChannelNotOpen if the Listener is not listening on the specified channel.
|
||||
// Returns immediately with no error if there is no connection. Note that you
|
||||
// might still get notifications for this channel even after Unlisten has
|
||||
// returned.
|
||||
//
|
||||
// The channel name is case-sensitive.
|
||||
func (l *Listener) Unlisten(channel string) error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.isClosed {
|
||||
return errListenerClosed
|
||||
}
|
||||
|
||||
// Similarly to LISTEN, this is not an error in Postgres, but it seems
|
||||
// useful to distinguish from the normal conditions.
|
||||
_, exists := l.channels[channel]
|
||||
if !exists {
|
||||
return ErrChannelNotOpen
|
||||
}
|
||||
|
||||
if l.cn != nil {
|
||||
// Similarly to Listen (see comment in that function), the caller
|
||||
// should only be bothered with an error if it came from the backend as
|
||||
// a response to our query.
|
||||
gotResponse, err := l.cn.Unlisten(channel)
|
||||
if gotResponse && err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Don't bother waiting for resync if there's no connection.
|
||||
delete(l.channels, channel)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnlistenAll removes all channels from the Listener's channel list. Returns
|
||||
// immediately with no error if there is no connection. Note that you might
|
||||
// still get notifications for any of the deleted channels even after
|
||||
// UnlistenAll has returned.
|
||||
func (l *Listener) UnlistenAll() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.isClosed {
|
||||
return errListenerClosed
|
||||
}
|
||||
|
||||
if l.cn != nil {
|
||||
// Similarly to Listen (see comment in that function), the caller
|
||||
// should only be bothered with an error if it came from the backend as
|
||||
// a response to our query.
|
||||
gotResponse, err := l.cn.UnlistenAll()
|
||||
if gotResponse && err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Don't bother waiting for resync if there's no connection.
|
||||
l.channels = make(map[string]struct{})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping the remote server to make sure it's alive. Non-nil return value means
|
||||
// that there is no active connection.
|
||||
func (l *Listener) Ping() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.isClosed {
|
||||
return errListenerClosed
|
||||
}
|
||||
if l.cn == nil {
|
||||
return errors.New("no connection")
|
||||
}
|
||||
|
||||
return l.cn.Ping()
|
||||
}
|
||||
|
||||
// Clean up after losing the server connection. Returns l.cn.Err(), which
|
||||
// should have the reason the connection was lost.
|
||||
func (l *Listener) disconnectCleanup() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
// sanity check; can't look at Err() until the channel has been closed
|
||||
select {
|
||||
case _, ok := <-l.connNotificationChan:
|
||||
if ok {
|
||||
panic("connNotificationChan not closed")
|
||||
}
|
||||
default:
|
||||
panic("connNotificationChan not closed")
|
||||
}
|
||||
|
||||
err := l.cn.Err()
|
||||
l.cn.Close()
|
||||
l.cn = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Synchronize the list of channels we want to be listening on with the server
|
||||
// after the connection has been established.
|
||||
func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error {
|
||||
doneChan := make(chan error)
|
||||
go func() {
|
||||
for channel := range l.channels {
|
||||
// If we got a response, return that error to our caller as it's
|
||||
// going to be more descriptive than cn.Err().
|
||||
gotResponse, err := cn.Listen(channel)
|
||||
if gotResponse && err != nil {
|
||||
doneChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
// If we couldn't reach the server, wait for notificationChan to
|
||||
// close and then return the error message from the connection, as
|
||||
// per ListenerConn's interface.
|
||||
if err != nil {
|
||||
for _ = range notificationChan {
|
||||
}
|
||||
doneChan <- cn.Err()
|
||||
return
|
||||
}
|
||||
}
|
||||
doneChan <- nil
|
||||
}()
|
||||
|
||||
// Ignore notifications while synchronization is going on to avoid
|
||||
// deadlocks. We have to send a nil notification over Notify anyway as
|
||||
// we can't possibly know which notifications (if any) were lost while
|
||||
// the connection was down, so there's no reason to try and process
|
||||
// these messages at all.
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-notificationChan:
|
||||
if !ok {
|
||||
notificationChan = nil
|
||||
}
|
||||
|
||||
case err := <-doneChan:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// caller should NOT be holding l.lock
|
||||
func (l *Listener) closed() bool {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
return l.isClosed
|
||||
}
|
||||
|
||||
func (l *Listener) connect() error {
|
||||
notificationChan := make(chan *Notification, 32)
|
||||
cn, err := newDialListenerConn(l.dialer, l.name, notificationChan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
err = l.resync(cn, notificationChan)
|
||||
if err != nil {
|
||||
cn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
l.cn = cn
|
||||
l.connNotificationChan = notificationChan
|
||||
l.reconnectCond.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close disconnects the Listener from the database and shuts it down.
|
||||
// Subsequent calls to its methods will return an error. Close returns an
|
||||
// error if the connection has already been closed.
|
||||
func (l *Listener) Close() error {
|
||||
l.lock.Lock()
|
||||
defer l.lock.Unlock()
|
||||
|
||||
if l.isClosed {
|
||||
return errListenerClosed
|
||||
}
|
||||
|
||||
if l.cn != nil {
|
||||
l.cn.Close()
|
||||
}
|
||||
l.isClosed = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) emitEvent(event ListenerEventType, err error) {
|
||||
if l.eventCallback != nil {
|
||||
l.eventCallback(event, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Main logic here: maintain a connection to the server when possible, wait
|
||||
// for notifications and emit events.
|
||||
func (l *Listener) listenerConnLoop() {
|
||||
var nextReconnect time.Time
|
||||
|
||||
reconnectInterval := l.minReconnectInterval
|
||||
for {
|
||||
for {
|
||||
err := l.connect()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if l.closed() {
|
||||
return
|
||||
}
|
||||
l.emitEvent(ListenerEventConnectionAttemptFailed, err)
|
||||
|
||||
time.Sleep(reconnectInterval)
|
||||
reconnectInterval *= 2
|
||||
if reconnectInterval > l.maxReconnectInterval {
|
||||
reconnectInterval = l.maxReconnectInterval
|
||||
}
|
||||
}
|
||||
|
||||
if nextReconnect.IsZero() {
|
||||
l.emitEvent(ListenerEventConnected, nil)
|
||||
} else {
|
||||
l.emitEvent(ListenerEventReconnected, nil)
|
||||
l.Notify <- nil
|
||||
}
|
||||
|
||||
reconnectInterval = l.minReconnectInterval
|
||||
nextReconnect = time.Now().Add(reconnectInterval)
|
||||
|
||||
for {
|
||||
notification, ok := <-l.connNotificationChan
|
||||
if !ok {
|
||||
// lost connection, loop again
|
||||
break
|
||||
}
|
||||
l.Notify <- notification
|
||||
}
|
||||
|
||||
err := l.disconnectCleanup()
|
||||
if l.closed() {
|
||||
return
|
||||
}
|
||||
l.emitEvent(ListenerEventDisconnected, err)
|
||||
|
||||
time.Sleep(nextReconnect.Sub(time.Now()))
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) listenerMain() {
|
||||
l.listenerConnLoop()
|
||||
close(l.Notify)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue