mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-11 16:59:59 +08:00
Compare commits
177 Commits
tools/goct
...
go1.16
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ea0a843f8 | ||
|
|
9e0e01b2bc | ||
|
|
af50a80d01 | ||
|
|
703fb8d970 | ||
|
|
e964e530e1 | ||
|
|
52265087d1 | ||
|
|
b4c2677eb9 | ||
|
|
30296fb1ca | ||
|
|
356c80defd | ||
|
|
8c31525378 | ||
|
|
2cf09f3c36 | ||
|
|
d41e542c92 | ||
|
|
265a24ac6d | ||
|
|
7d88fc39dc | ||
|
|
6957b6a344 | ||
|
|
bca6a230c8 | ||
|
|
cc8413d683 | ||
|
|
3842283fa8 | ||
|
|
fe13a533f5 | ||
|
|
7a327ccda4 | ||
|
|
06e4507406 | ||
|
|
8794d5b753 | ||
|
|
9bfa63d995 | ||
|
|
a432b121fb | ||
|
|
b61c94bb66 | ||
|
|
93fcf899dc | ||
|
|
9f4b3bae92 | ||
|
|
805cb87d98 | ||
|
|
366131640e | ||
|
|
956884a3ff | ||
|
|
f571cb8af2 | ||
|
|
cc5acf3b90 | ||
|
|
e1aa665443 | ||
|
|
cd357d9484 | ||
|
|
6d4d7cbd6b | ||
|
|
c593b5b531 | ||
|
|
fd5b38b07c | ||
|
|
41efb48f55 | ||
|
|
0ef3626839 | ||
|
|
77a72b16e9 | ||
|
|
21566f1b7a | ||
|
|
b2646e228b | ||
|
|
588b883710 | ||
|
|
033910bbd8 | ||
|
|
530dd79e3f | ||
|
|
cd5263ac75 | ||
|
|
ea3302a468 | ||
|
|
abf15b373c | ||
|
|
a865e9ee29 | ||
|
|
f8292198cf | ||
|
|
016d965f56 | ||
|
|
95d7c73409 | ||
|
|
939ef2a181 | ||
|
|
f0b8dd45fe | ||
|
|
0ba9335b04 | ||
|
|
04f181f0b4 | ||
|
|
89f841c126 | ||
|
|
d785c8c377 | ||
|
|
687a1d15da | ||
|
|
aaa974e1ad | ||
|
|
2779568ccf | ||
|
|
f7d50ae626 | ||
|
|
33594ea350 | ||
|
|
ee2ec974c4 | ||
|
|
fd2f2f0f54 | ||
|
|
86a2429d7d | ||
|
|
e5fe5dcc50 | ||
|
|
b510e7c242 | ||
|
|
dfe92e709f | ||
|
|
cb649cf627 | ||
|
|
ce19a5ade6 | ||
|
|
6dc56de714 | ||
|
|
f3369f8e81 | ||
|
|
c9b05ae07e | ||
|
|
32a59dbc27 | ||
|
|
ba0dff2d61 | ||
|
|
10da5e0424 | ||
|
|
4bed34090f | ||
|
|
2bfecf9354 | ||
|
|
6d129e0264 | ||
|
|
a2df1bb164 | ||
|
|
5f02e623f5 | ||
|
|
963b52fb1b | ||
|
|
02265d0bfe | ||
|
|
2e57e91826 | ||
|
|
82c642d3f4 | ||
|
|
b2571883ca | ||
|
|
00ff50c2cc | ||
|
|
4d7fa08b0b | ||
|
|
367afb544c | ||
|
|
43b8c7f641 | ||
|
|
a2dcb0079a | ||
|
|
f9619328f2 | ||
|
|
bae061a67e | ||
|
|
0b176e17ac | ||
|
|
6340e24c17 | ||
|
|
74e0676617 | ||
|
|
0defb7522f | ||
|
|
0c786ca849 | ||
|
|
26c541b9cb | ||
|
|
ade6f9ee46 | ||
|
|
f4502171ea | ||
|
|
8157e2118d | ||
|
|
e52dace416 | ||
|
|
dc260f196a | ||
|
|
559726112c | ||
|
|
a5fcf24c04 | ||
|
|
fc9b3ffdc1 | ||
|
|
e71c505e94 | ||
|
|
21c49009c0 | ||
|
|
69d355eb4b | ||
|
|
83f88d177f | ||
|
|
641ebf1667 | ||
|
|
cf435bfcc1 | ||
|
|
28f1b15b8e | ||
|
|
42413dc294 | ||
|
|
ec7ac43948 | ||
|
|
deefc1a8eb | ||
|
|
036328f1ea | ||
|
|
85057a623d | ||
|
|
1c544a26be | ||
|
|
20a61ce43e | ||
|
|
dd294e8cd6 | ||
|
|
3e9d0161bc | ||
|
|
cf6c349118 | ||
|
|
c7a0ec428c | ||
|
|
ce1c02f4f9 | ||
|
|
c3756a8f1c | ||
|
|
f4fd735aee | ||
|
|
683d793719 | ||
|
|
affbcb5698 | ||
|
|
f0d1722bbd | ||
|
|
c4f8eca459 | ||
|
|
251c071418 | ||
|
|
6652c4e445 | ||
|
|
f73613dff0 | ||
|
|
7a75dce465 | ||
|
|
801f1adf71 | ||
|
|
f76b976262 | ||
|
|
a49f9060c2 | ||
|
|
ebe28882eb | ||
|
|
fdc57d07d7 | ||
|
|
ef22042f4d | ||
|
|
944193ce25 | ||
|
|
dcfc9b79f1 | ||
|
|
b7052854bb | ||
|
|
4729a16142 | ||
|
|
3604659027 | ||
|
|
9f7f94b673 | ||
|
|
0b3629b636 | ||
|
|
a644ec7edd | ||
|
|
9941055eaa | ||
|
|
10fd9131a1 | ||
|
|
90828a0d4a | ||
|
|
b1c3c21c81 | ||
|
|
97a8b3ade5 | ||
|
|
95a5f64493 | ||
|
|
20e659749a | ||
|
|
94708cc78f | ||
|
|
06fafd2153 | ||
|
|
79de932646 | ||
|
|
b562e940e7 | ||
|
|
69068cdaf0 | ||
|
|
f25788ebea | ||
|
|
1293c4321b | ||
|
|
e3e08a7396 | ||
|
|
4b071f4c33 | ||
|
|
81831b60a9 | ||
|
|
1677a4dceb | ||
|
|
dac3600b53 | ||
|
|
3db64c7d47 | ||
|
|
7eb6aae949 | ||
|
|
07128213d6 | ||
|
|
9504d30049 | ||
|
|
ce73b9a85c | ||
|
|
4d2a146733 | ||
|
|
46e236fef7 |
@@ -1,3 +1,6 @@
|
||||
comment: false
|
||||
comment:
|
||||
layout: "flags, files"
|
||||
behavior: once
|
||||
require_changes: true
|
||||
ignore:
|
||||
- "tools"
|
||||
19
.github/workflows/go.yml
vendored
19
.github/workflows/go.yml
vendored
@@ -11,15 +11,17 @@ jobs:
|
||||
name: Linux
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ^1.16
|
||||
check-latest: true
|
||||
cache: true
|
||||
id: go
|
||||
|
||||
- name: Check out code into the Go module directory
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Get dependencies
|
||||
run: |
|
||||
go get -v -t -d ./...
|
||||
@@ -40,13 +42,16 @@ jobs:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout codebase
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Go 1.x
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
go-version: ^1.16
|
||||
|
||||
- name: Checkout codebase
|
||||
uses: actions/checkout@v3
|
||||
# use 1.16 to guarantee Go 1.16 compatibility
|
||||
go-version: 1.16
|
||||
check-latest: true
|
||||
cache: true
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
|
||||
28
ROADMAP.md
28
ROADMAP.md
@@ -1,28 +0,0 @@
|
||||
# go-zero Roadmap
|
||||
|
||||
This document defines a high level roadmap for go-zero development and upcoming releases.
|
||||
Community and contributor involvement is vital for successfully implementing all desired items for each release.
|
||||
We hope that the items listed below will inspire further engagement from the community to keep go-zero progressing and shipping exciting and valuable features.
|
||||
|
||||
## 2021 Q2
|
||||
- [x] Support service discovery through K8S client api
|
||||
- [x] Log full sql statements for easier sql problem solving
|
||||
|
||||
## 2021 Q3
|
||||
- [x] Support `goctl model pg` to support PostgreSQL code generation
|
||||
- [x] Adapt builtin tracing mechanism to opentracing solutions
|
||||
|
||||
## 2021 Q4
|
||||
- [x] Support `username/password` authentication in ETCD
|
||||
- [x] Support `SSL/TLS` in ETCD
|
||||
- [x] Support `SSL/TLS` in `zRPC`
|
||||
- [x] Support `TLS` in redis connections
|
||||
- [x] Support `goctl bug` to report bugs conveniently
|
||||
|
||||
## 2022
|
||||
- [x] Support `context` in redis related methods for timeout and tracing
|
||||
- [x] Support `context` in sql related methods for timeout and tracing
|
||||
- [x] Support `context` in mongodb related methods for timeout and tracing
|
||||
- [x] Add `httpc.Do` with HTTP call governance, like circuit breaker etc.
|
||||
- [ ] Support `goctl doctor` command to report potential issues for given service
|
||||
- [ ] Support `goctl mock` command to start a mocking server with given `.api` file
|
||||
@@ -20,16 +20,16 @@ func (b noOpBreaker) Do(req func() error) error {
|
||||
return req()
|
||||
}
|
||||
|
||||
func (b noOpBreaker) DoWithAcceptable(req func() error, acceptable Acceptable) error {
|
||||
func (b noOpBreaker) DoWithAcceptable(req func() error, _ Acceptable) error {
|
||||
return req()
|
||||
}
|
||||
|
||||
func (b noOpBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
|
||||
func (b noOpBreaker) DoWithFallback(req func() error, _ func(err error) error) error {
|
||||
return req()
|
||||
}
|
||||
|
||||
func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
|
||||
acceptable Acceptable) error {
|
||||
func (b noOpBreaker) DoWithFallbackAcceptable(req func() error, _ func(err error) error,
|
||||
_ Acceptable) error {
|
||||
return req()
|
||||
}
|
||||
|
||||
@@ -38,5 +38,5 @@ type nopPromise struct{}
|
||||
func (p nopPromise) Accept() {
|
||||
}
|
||||
|
||||
func (p nopPromise) Reject(reason string) {
|
||||
func (p nopPromise) Reject(_ string) {
|
||||
}
|
||||
|
||||
@@ -32,9 +32,11 @@ func NewECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
// BlockSize returns the mode's block size.
|
||||
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
// why we don't return error is because cipher.BlockMode doesn't allow this
|
||||
// CryptBlocks encrypts a number of blocks. The length of src must be a multiple of
|
||||
// the block size. Dst and src must overlap entirely or not at all.
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
logx.Error("crypto/cipher: input not full blocks")
|
||||
@@ -59,11 +61,13 @@ func NewECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
// BlockSize returns the mode's block size.
|
||||
func (x *ecbDecrypter) BlockSize() int {
|
||||
return x.blockSize
|
||||
}
|
||||
|
||||
// why we don't return error is because cipher.BlockMode doesn't allow this
|
||||
// CryptBlocks decrypts a number of blocks. The length of src must be a multiple of
|
||||
// the block size. Dst and src must overlap entirely or not at all.
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
logx.Error("crypto/cipher: input not full blocks")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
@@ -10,7 +11,8 @@ import (
|
||||
func TestAesEcb(t *testing.T) {
|
||||
var (
|
||||
key = []byte("q4t7w!z%C*F-JaNdRgUjXn2r5u8x/A?D")
|
||||
val = []byte("hello")
|
||||
val = []byte("helloworld")
|
||||
valLong = []byte("helloworldlong..")
|
||||
badKey1 = []byte("aaaaaaaaa")
|
||||
// more than 32 chars
|
||||
badKey2 = []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
|
||||
@@ -31,6 +33,39 @@ func TestAesEcb(t *testing.T) {
|
||||
src, err := EcbDecrypt(key, dst)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, val, src)
|
||||
block, err := aes.NewCipher(key)
|
||||
assert.NoError(t, err)
|
||||
encrypter := NewECBEncrypter(block)
|
||||
assert.Equal(t, 16, encrypter.BlockSize())
|
||||
decrypter := NewECBDecrypter(block)
|
||||
assert.Equal(t, 16, decrypter.BlockSize())
|
||||
|
||||
dst = make([]byte, 8)
|
||||
encrypter.CryptBlocks(dst, val)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
dst = make([]byte, 8)
|
||||
encrypter.CryptBlocks(dst, valLong)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
dst = make([]byte, 8)
|
||||
decrypter.CryptBlocks(dst, val)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
dst = make([]byte, 8)
|
||||
decrypter.CryptBlocks(dst, valLong)
|
||||
for _, b := range dst {
|
||||
assert.Equal(t, byte(0), b)
|
||||
}
|
||||
|
||||
_, err = EcbEncryptBase64("cTR0N3dDKkYtSmFOZFJnVWpYbjJyNXU4eC9BP0QK", "aGVsbG93b3JsZGxvbmcuLgo=")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAesEcbBase64(t *testing.T) {
|
||||
|
||||
@@ -80,3 +80,17 @@ func TestKeyBytes(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, len(key.Bytes()) > 0)
|
||||
}
|
||||
|
||||
func TestDHOnErrors(t *testing.T) {
|
||||
key, err := GenerateKey()
|
||||
assert.Nil(t, err)
|
||||
assert.NotEmpty(t, key.Bytes())
|
||||
_, err = ComputeKey(key.PubKey, key.PriKey)
|
||||
assert.NoError(t, err)
|
||||
_, err = ComputeKey(nil, key.PriKey)
|
||||
assert.Error(t, err)
|
||||
_, err = ComputeKey(key.PubKey, nil)
|
||||
assert.Error(t, err)
|
||||
|
||||
assert.NotNil(t, NewPublicKey([]byte("")))
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import "sync"
|
||||
type Ring struct {
|
||||
elements []interface{}
|
||||
index int
|
||||
lock sync.Mutex
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRing returns a Ring object with the given size n.
|
||||
@@ -31,8 +31,8 @@ func (r *Ring) Add(v interface{}) {
|
||||
|
||||
// Take takes all items from r.
|
||||
func (r *Ring) Take() []interface{} {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.lock.RLock()
|
||||
defer r.lock.RUnlock()
|
||||
|
||||
var size int
|
||||
var start int
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewSet() *Set {
|
||||
}
|
||||
}
|
||||
|
||||
// NewUnmanagedSet returns a unmanaged Set, which can put values with different types.
|
||||
// NewUnmanagedSet returns an unmanaged Set, which can put values with different types.
|
||||
func NewUnmanagedSet() *Set {
|
||||
return &Set{
|
||||
data: make(map[interface{}]lang.PlaceholderType),
|
||||
@@ -213,23 +213,23 @@ func (s *Set) validate(i interface{}) {
|
||||
switch i.(type) {
|
||||
case int:
|
||||
if s.tp != intType {
|
||||
logx.Errorf("Error: element is int, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is int, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case int64:
|
||||
if s.tp != int64Type {
|
||||
logx.Errorf("Error: element is int64, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is int64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint:
|
||||
if s.tp != uintType {
|
||||
logx.Errorf("Error: element is uint, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is uint, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case uint64:
|
||||
if s.tp != uint64Type {
|
||||
logx.Errorf("Error: element is uint64, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is uint64, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
case string:
|
||||
if s.tp != stringType {
|
||||
logx.Errorf("Error: element is string, but set contains elements with type %d", s.tp)
|
||||
logx.Errorf("element is string, but set contains elements with type %d", s.tp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,10 +69,11 @@ func NewTimingWheel(interval time.Duration, numSlots int, execute Execute) (*Tim
|
||||
interval, numSlots, execute)
|
||||
}
|
||||
|
||||
return newTimingWheelWithClock(interval, numSlots, execute, timex.NewTicker(interval))
|
||||
return NewTimingWheelWithTicker(interval, numSlots, execute, timex.NewTicker(interval))
|
||||
}
|
||||
|
||||
func newTimingWheelWithClock(interval time.Duration, numSlots int, execute Execute,
|
||||
// NewTimingWheelWithTicker returns a TimingWheel with the given ticker.
|
||||
func NewTimingWheelWithTicker(interval time.Duration, numSlots int, execute Execute,
|
||||
ticker timex.Ticker) (*TimingWheel, error) {
|
||||
tw := &TimingWheel{
|
||||
interval: interval,
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestNewTimingWheel(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_Drain(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
}, ticker)
|
||||
tw.SetTimer("first", 3, testStep*4)
|
||||
tw.SetTimer("second", 5, testStep*7)
|
||||
@@ -62,7 +62,7 @@ func TestTimingWheel_Drain(t *testing.T) {
|
||||
func TestTimingWheel_SetTimerSoon(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -78,7 +78,7 @@ func TestTimingWheel_SetTimerSoon(t *testing.T) {
|
||||
func TestTimingWheel_SetTimerTwice(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 5, v.(int))
|
||||
@@ -96,7 +96,7 @@ func TestTimingWheel_SetTimerTwice(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
defer tw.Stop()
|
||||
assert.NotPanics(t, func() {
|
||||
tw.SetTimer("any", 3, -testStep)
|
||||
@@ -105,7 +105,7 @@ func TestTimingWheel_SetTimerWrongDelay(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw.Stop()
|
||||
assert.Equal(t, ErrClosed, tw.SetTimer("any", 3, testStep))
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func TestTimingWheel_SetTimerAfterClose(t *testing.T) {
|
||||
func TestTimingWheel_MoveTimer(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -139,7 +139,7 @@ func TestTimingWheel_MoveTimer(t *testing.T) {
|
||||
func TestTimingWheel_MoveTimerSoon(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 3, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 3, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -155,7 +155,7 @@ func TestTimingWheel_MoveTimerSoon(t *testing.T) {
|
||||
func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
|
||||
run := syncx.NewAtomicBool()
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.True(t, run.CompareAndSwap(false, true))
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
@@ -173,7 +173,7 @@ func TestTimingWheel_MoveTimerEarlier(t *testing.T) {
|
||||
|
||||
func TestTimingWheel_RemoveTimer(t *testing.T) {
|
||||
ticker := timex.NewFakeTicker()
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {}, ticker)
|
||||
tw.SetTimer("any", 3, testStep)
|
||||
assert.NotPanics(t, func() {
|
||||
tw.RemoveTimer("any")
|
||||
@@ -236,7 +236,7 @@ func TestTimingWheel_SetTimer(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
assert.Equal(t, 1, key.(int))
|
||||
assert.Equal(t, 2, value.(int))
|
||||
actual = atomic.LoadInt32(&count)
|
||||
@@ -317,7 +317,7 @@ func TestTimingWheel_SetAndMoveThenStart(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -405,7 +405,7 @@ func TestTimingWheel_SetAndMoveTwice(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -486,7 +486,7 @@ func TestTimingWheel_ElapsedAndSet(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -577,7 +577,7 @@ func TestTimingWheel_ElapsedAndSetThenMove(t *testing.T) {
|
||||
}
|
||||
var actual int32
|
||||
done := make(chan lang.PlaceholderType)
|
||||
tw, err := newTimingWheelWithClock(testStep, test.slots, func(key, value interface{}) {
|
||||
tw, err := NewTimingWheelWithTicker(testStep, test.slots, func(key, value interface{}) {
|
||||
actual = atomic.LoadInt32(&count)
|
||||
close(done)
|
||||
}, ticker)
|
||||
@@ -612,7 +612,7 @@ func TestMoveAndRemoveTask(t *testing.T) {
|
||||
}
|
||||
}
|
||||
var keys []int
|
||||
tw, _ := newTimingWheelWithClock(testStep, 10, func(k, v interface{}) {
|
||||
tw, _ := NewTimingWheelWithTicker(testStep, 10, func(k, v interface{}) {
|
||||
assert.Equal(t, "any", k)
|
||||
assert.Equal(t, 3, v.(int))
|
||||
keys = append(keys, v.(int))
|
||||
|
||||
@@ -5,16 +5,37 @@ import (
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/jsonx"
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
"github.com/zeromicro/go-zero/internal/encoding"
|
||||
)
|
||||
|
||||
var loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
const jsonTagKey = "json"
|
||||
|
||||
var (
|
||||
fillDefaultUnmarshaler = mapping.NewUnmarshaler(jsonTagKey, mapping.WithDefault())
|
||||
loaders = map[string]func([]byte, interface{}) error{
|
||||
".json": LoadFromJsonBytes,
|
||||
".toml": LoadFromTomlBytes,
|
||||
".yaml": LoadFromYamlBytes,
|
||||
".yml": LoadFromYamlBytes,
|
||||
}
|
||||
)
|
||||
|
||||
// children and mapField should not be both filled.
|
||||
// named fields and map cannot be bound to the same field name.
|
||||
type fieldInfo struct {
|
||||
children map[string]*fieldInfo
|
||||
mapField *fieldInfo
|
||||
}
|
||||
|
||||
// FillDefault fills the default values for the given v,
|
||||
// and the premise is that the value of v must be guaranteed to be empty.
|
||||
func FillDefault(v interface{}) error {
|
||||
return fillDefaultUnmarshaler.Unmarshal(map[string]interface{}{}, v)
|
||||
}
|
||||
|
||||
// Load loads config into v from file, .json, .yaml and .yml are acceptable.
|
||||
@@ -49,7 +70,19 @@ func LoadConfig(file string, v interface{}, opts ...Option) error {
|
||||
|
||||
// LoadFromJsonBytes loads config into v from content json bytes.
|
||||
func LoadFromJsonBytes(content []byte, v interface{}) error {
|
||||
return mapping.UnmarshalJsonBytes(content, v)
|
||||
info, err := buildFieldsInfo(reflect.TypeOf(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var m map[string]interface{}
|
||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lowerCaseKeyMap := toLowerCaseKeyMap(m, info)
|
||||
|
||||
return mapping.UnmarshalJsonMap(lowerCaseKeyMap, v, mapping.WithCanonicalKeyFunc(toLowerCase))
|
||||
}
|
||||
|
||||
// LoadConfigFromJsonBytes loads config into v from content json bytes.
|
||||
@@ -60,12 +93,22 @@ func LoadConfigFromJsonBytes(content []byte, v interface{}) error {
|
||||
|
||||
// LoadFromTomlBytes loads config into v from content toml bytes.
|
||||
func LoadFromTomlBytes(content []byte, v interface{}) error {
|
||||
return mapping.UnmarshalTomlBytes(content, v)
|
||||
b, err := encoding.TomlToJson(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return LoadFromJsonBytes(b, v)
|
||||
}
|
||||
|
||||
// LoadFromYamlBytes loads config into v from content yaml bytes.
|
||||
func LoadFromYamlBytes(content []byte, v interface{}) error {
|
||||
return mapping.UnmarshalYamlBytes(content, v)
|
||||
b, err := encoding.YamlToJson(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return LoadFromJsonBytes(b, v)
|
||||
}
|
||||
|
||||
// LoadConfigFromYamlBytes loads config into v from content yaml bytes.
|
||||
@@ -80,3 +123,205 @@ func MustLoad(path string, v interface{}, opts ...Option) {
|
||||
log.Fatalf("error: config file %s, %s", path, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func addOrMergeFields(info *fieldInfo, key string, child *fieldInfo) error {
|
||||
if prev, ok := info.children[key]; ok {
|
||||
if child.mapField != nil {
|
||||
return newDupKeyError(key)
|
||||
}
|
||||
|
||||
if err := mergeFields(prev, key, child.children); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
info.children[key] = child
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildAnonymousFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
fields, err := buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range fields.children {
|
||||
if err = addOrMergeFields(info, k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Map:
|
||||
elemField, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
mapField: elemField,
|
||||
}
|
||||
default:
|
||||
if _, ok := info.children[lowerCaseName]; ok {
|
||||
return newDupKeyError(lowerCaseName)
|
||||
}
|
||||
|
||||
info.children[lowerCaseName] = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
tp = mapping.Deref(tp)
|
||||
|
||||
switch tp.Kind() {
|
||||
case reflect.Struct:
|
||||
return buildStructFieldsInfo(tp)
|
||||
case reflect.Array, reflect.Slice:
|
||||
return buildFieldsInfo(mapping.Deref(tp.Elem()))
|
||||
case reflect.Chan, reflect.Func:
|
||||
return nil, fmt.Errorf("unsupported type: %s", tp.Kind())
|
||||
default:
|
||||
return &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildNamedFieldInfo(info *fieldInfo, lowerCaseName string, ft reflect.Type) error {
|
||||
var finfo *fieldInfo
|
||||
var err error
|
||||
|
||||
switch ft.Kind() {
|
||||
case reflect.Struct:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Array, reflect.Slice:
|
||||
finfo, err = buildFieldsInfo(ft.Elem())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Map:
|
||||
elemInfo, err := buildFieldsInfo(mapping.Deref(ft.Elem()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
finfo = &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
mapField: elemInfo,
|
||||
}
|
||||
default:
|
||||
finfo, err = buildFieldsInfo(ft)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return addOrMergeFields(info, lowerCaseName, finfo)
|
||||
}
|
||||
|
||||
func buildStructFieldsInfo(tp reflect.Type) (*fieldInfo, error) {
|
||||
info := &fieldInfo{
|
||||
children: make(map[string]*fieldInfo),
|
||||
}
|
||||
|
||||
for i := 0; i < tp.NumField(); i++ {
|
||||
field := tp.Field(i)
|
||||
name := field.Name
|
||||
lowerCaseName := toLowerCase(name)
|
||||
ft := mapping.Deref(field.Type)
|
||||
// flatten anonymous fields
|
||||
if field.Anonymous {
|
||||
if err := buildAnonymousFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err := buildNamedFieldInfo(info, lowerCaseName, ft); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func mergeFields(prev *fieldInfo, key string, children map[string]*fieldInfo) error {
|
||||
if len(prev.children) == 0 || len(children) == 0 {
|
||||
return newDupKeyError(key)
|
||||
}
|
||||
|
||||
// merge fields
|
||||
for k, v := range children {
|
||||
if _, ok := prev.children[k]; ok {
|
||||
return newDupKeyError(k)
|
||||
}
|
||||
|
||||
prev.children[k] = v
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func toLowerCase(s string) string {
|
||||
return strings.ToLower(s)
|
||||
}
|
||||
|
||||
func toLowerCaseInterface(v interface{}, info *fieldInfo) interface{} {
|
||||
switch vv := v.(type) {
|
||||
case map[string]interface{}:
|
||||
return toLowerCaseKeyMap(vv, info)
|
||||
case []interface{}:
|
||||
var arr []interface{}
|
||||
for _, vvv := range vv {
|
||||
arr = append(arr, toLowerCaseInterface(vvv, info))
|
||||
}
|
||||
return arr
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func toLowerCaseKeyMap(m map[string]interface{}, info *fieldInfo) map[string]interface{} {
|
||||
res := make(map[string]interface{})
|
||||
|
||||
for k, v := range m {
|
||||
ti, ok := info.children[k]
|
||||
if ok {
|
||||
res[k] = toLowerCaseInterface(v, ti)
|
||||
continue
|
||||
}
|
||||
|
||||
lk := toLowerCase(k)
|
||||
if ti, ok = info.children[lk]; ok {
|
||||
res[lk] = toLowerCaseInterface(v, ti)
|
||||
} else if info.mapField != nil {
|
||||
res[k] = toLowerCaseInterface(v, info.mapField)
|
||||
} else {
|
||||
res[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
type dupKeyError struct {
|
||||
key string
|
||||
}
|
||||
|
||||
func newDupKeyError(key string) dupKeyError {
|
||||
return dupKeyError{key: key}
|
||||
}
|
||||
|
||||
func (e dupKeyError) Error() string {
|
||||
return fmt.Sprintf("duplicated key %s", e.key)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
)
|
||||
|
||||
var dupErr dupKeyError
|
||||
|
||||
func TestLoadConfig_notExists(t *testing.T) {
|
||||
assert.NotNil(t, Load("not_a_file", nil))
|
||||
}
|
||||
@@ -17,7 +19,7 @@ func TestLoadConfig_notRecogFile(t *testing.T) {
|
||||
filename, err := fs.TempFilenameWithText("hello")
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(filename)
|
||||
assert.NotNil(t, Load(filename, nil))
|
||||
assert.NotNil(t, LoadConfig(filename, nil))
|
||||
}
|
||||
|
||||
func TestConfigJson(t *testing.T) {
|
||||
@@ -56,6 +58,22 @@ func TestConfigJson(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromJsonBytesArray(t *testing.T) {
|
||||
input := []byte(`{"users": [{"name": "foo"}, {"Name": "bar"}]}`)
|
||||
var val struct {
|
||||
Users []struct {
|
||||
Name string
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadConfigFromJsonBytes(input, &val))
|
||||
var expect []string
|
||||
for _, user := range val.Users {
|
||||
expect = append(expect, user.Name)
|
||||
}
|
||||
assert.EqualValues(t, []string{"foo", "bar"}, expect)
|
||||
}
|
||||
|
||||
func TestConfigToml(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
@@ -81,6 +99,89 @@ d = "abcd!@#$112"
|
||||
assert.Equal(t, "abcd!@#$112", val.D)
|
||||
}
|
||||
|
||||
func TestConfigOptional(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
c = "FOO"
|
||||
d = "abcd"
|
||||
`
|
||||
tmpfile, err := createTempFile(".toml", text)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(tmpfile)
|
||||
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
B int `json:"b,optional"`
|
||||
C string `json:"c,optional=B"`
|
||||
D string `json:"d,optional=b"`
|
||||
}
|
||||
if assert.NoError(t, Load(tmpfile, &val)) {
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "FOO", val.C)
|
||||
assert.Equal(t, "abcd", val.D)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigJsonCanonical(t *testing.T) {
|
||||
text := []byte(`{"a": "foo", "B": "bar"}`)
|
||||
|
||||
var val1 struct {
|
||||
A string `json:"a"`
|
||||
B string `json:"b"`
|
||||
}
|
||||
var val2 struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
assert.NoError(t, LoadFromJsonBytes(text, &val1))
|
||||
assert.Equal(t, "foo", val1.A)
|
||||
assert.Equal(t, "bar", val1.B)
|
||||
assert.NoError(t, LoadFromJsonBytes(text, &val2))
|
||||
assert.Equal(t, "foo", val2.A)
|
||||
assert.Equal(t, "bar", val2.B)
|
||||
}
|
||||
|
||||
func TestConfigTomlCanonical(t *testing.T) {
|
||||
text := []byte(`a = "foo"
|
||||
B = "bar"`)
|
||||
|
||||
var val1 struct {
|
||||
A string `json:"a"`
|
||||
B string `json:"b"`
|
||||
}
|
||||
var val2 struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
assert.NoError(t, LoadFromTomlBytes(text, &val1))
|
||||
assert.Equal(t, "foo", val1.A)
|
||||
assert.Equal(t, "bar", val1.B)
|
||||
assert.NoError(t, LoadFromTomlBytes(text, &val2))
|
||||
assert.Equal(t, "foo", val2.A)
|
||||
assert.Equal(t, "bar", val2.B)
|
||||
}
|
||||
|
||||
func TestConfigYamlCanonical(t *testing.T) {
|
||||
text := []byte(`a: foo
|
||||
B: bar`)
|
||||
|
||||
var val1 struct {
|
||||
A string `json:"a"`
|
||||
B string `json:"b"`
|
||||
}
|
||||
var val2 struct {
|
||||
A string
|
||||
B string
|
||||
}
|
||||
assert.NoError(t, LoadConfigFromYamlBytes(text, &val1))
|
||||
assert.Equal(t, "foo", val1.A)
|
||||
assert.Equal(t, "bar", val1.B)
|
||||
assert.NoError(t, LoadFromYamlBytes(text, &val2))
|
||||
assert.Equal(t, "foo", val2.A)
|
||||
assert.Equal(t, "bar", val2.B)
|
||||
}
|
||||
|
||||
func TestConfigTomlEnv(t *testing.T) {
|
||||
text := `a = "foo"
|
||||
b = 1
|
||||
@@ -143,6 +244,784 @@ func TestConfigJsonEnv(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToCamelCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
input: "",
|
||||
expect: "",
|
||||
},
|
||||
{
|
||||
input: "A",
|
||||
expect: "a",
|
||||
},
|
||||
{
|
||||
input: "a",
|
||||
expect: "a",
|
||||
},
|
||||
{
|
||||
input: "hello_world",
|
||||
expect: "hello_world",
|
||||
},
|
||||
{
|
||||
input: "Hello_world",
|
||||
expect: "hello_world",
|
||||
},
|
||||
{
|
||||
input: "hello_World",
|
||||
expect: "hello_world",
|
||||
},
|
||||
{
|
||||
input: "helloWorld",
|
||||
expect: "helloworld",
|
||||
},
|
||||
{
|
||||
input: "HelloWorld",
|
||||
expect: "helloworld",
|
||||
},
|
||||
{
|
||||
input: "hello World",
|
||||
expect: "hello world",
|
||||
},
|
||||
{
|
||||
input: "Hello World",
|
||||
expect: "hello world",
|
||||
},
|
||||
{
|
||||
input: "Hello World",
|
||||
expect: "hello world",
|
||||
},
|
||||
{
|
||||
input: "Hello World foo_bar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello World foo_Bar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello World Foo_bar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello World Foo_Bar",
|
||||
expect: "hello world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "Hello.World Foo_Bar",
|
||||
expect: "hello.world foo_bar",
|
||||
},
|
||||
{
|
||||
input: "你好 World Foo_Bar",
|
||||
expect: "你好 world foo_bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
assert.Equal(t, test.expect, toLowerCase(test.input))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromJsonBytesError(t *testing.T) {
|
||||
var val struct{}
|
||||
assert.Error(t, LoadFromJsonBytes([]byte(`hello`), &val))
|
||||
}
|
||||
|
||||
func TestLoadFromTomlBytesError(t *testing.T) {
|
||||
var val struct{}
|
||||
assert.Error(t, LoadFromTomlBytes([]byte(`hello`), &val))
|
||||
}
|
||||
|
||||
func TestLoadFromYamlBytesError(t *testing.T) {
|
||||
var val struct{}
|
||||
assert.Error(t, LoadFromYamlBytes([]byte(`':hello`), &val))
|
||||
}
|
||||
|
||||
func TestLoadFromYamlBytes(t *testing.T) {
|
||||
input := []byte(`layer1:
|
||||
layer2:
|
||||
layer3: foo`)
|
||||
var val struct {
|
||||
Layer1 struct {
|
||||
Layer2 struct {
|
||||
Layer3 string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromYamlBytes(input, &val))
|
||||
assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlBytesTerm(t *testing.T) {
|
||||
input := []byte(`layer1:
|
||||
layer2:
|
||||
tls_conf: foo`)
|
||||
var val struct {
|
||||
Layer1 struct {
|
||||
Layer2 struct {
|
||||
Layer3 string `json:"tls_conf"`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromYamlBytes(input, &val))
|
||||
assert.Equal(t, "foo", val.Layer1.Layer2.Layer3)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlBytesLayers(t *testing.T) {
|
||||
input := []byte(`layer1:
|
||||
layer2:
|
||||
layer3: foo`)
|
||||
var val struct {
|
||||
Value string `json:"Layer1.Layer2.Layer3"`
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromYamlBytes(input, &val))
|
||||
assert.Equal(t, "foo", val.Value)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlay(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis RedisKey
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Server
|
||||
Redis Redis
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlayReverse(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis Redis
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Redis RedisKey
|
||||
Server
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestLoadFromYamlItemOverlayWithMap(t *testing.T) {
|
||||
type (
|
||||
Redis struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
RedisKey struct {
|
||||
Redis
|
||||
Key string
|
||||
}
|
||||
|
||||
Server struct {
|
||||
Redis RedisKey
|
||||
}
|
||||
|
||||
TestConfig struct {
|
||||
Server
|
||||
Redis map[string]interface{}
|
||||
}
|
||||
)
|
||||
|
||||
input := []byte(`Redis:
|
||||
Host: localhost
|
||||
Port: 6379
|
||||
Key: test
|
||||
`)
|
||||
|
||||
var c TestConfig
|
||||
assert.ErrorAs(t, LoadFromYamlBytes(input, &c), &dupErr)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMap(t *testing.T) {
|
||||
input := []byte(`{"foo":{"/mtproto.RPCTos": "bff.bff","bar":"baz"}}`)
|
||||
|
||||
var val struct {
|
||||
Foo map[string]string
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.Equal(t, "bff.bff", val.Foo["/mtproto.RPCTos"])
|
||||
assert.Equal(t, "baz", val.Foo["bar"])
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMapWithSliceElements(t *testing.T) {
|
||||
input := []byte(`{"foo":{"/mtproto.RPCTos": ["bff.bff", "any"],"bar":["baz", "qux"]}}`)
|
||||
|
||||
var val struct {
|
||||
Foo map[string][]string
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.EqualValues(t, []string{"bff.bff", "any"}, val.Foo["/mtproto.RPCTos"])
|
||||
assert.EqualValues(t, []string{"baz", "qux"}, val.Foo["bar"])
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesMapWithSliceOfStructs(t *testing.T) {
|
||||
input := []byte(`{"foo":{
|
||||
"/mtproto.RPCTos": [{"bar": "any"}],
|
||||
"bar":[{"bar": "qux"}, {"bar": "ever"}]}}`)
|
||||
|
||||
var val struct {
|
||||
Foo map[string][]struct {
|
||||
Bar string
|
||||
}
|
||||
}
|
||||
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &val))
|
||||
assert.Equal(t, 1, len(val.Foo["/mtproto.RPCTos"]))
|
||||
assert.Equal(t, "any", val.Foo["/mtproto.RPCTos"][0].Bar)
|
||||
assert.Equal(t, 2, len(val.Foo["bar"]))
|
||||
assert.Equal(t, "qux", val.Foo["bar"][0].Bar)
|
||||
assert.Equal(t, "ever", val.Foo["bar"][1].Bar)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithAnonymousField(t *testing.T) {
|
||||
type (
|
||||
Int int
|
||||
|
||||
InnerConf struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Conf struct {
|
||||
Int
|
||||
InnerConf
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
input = []byte(`{"Name": "hello", "int": 3}`)
|
||||
c Conf
|
||||
)
|
||||
assert.NoError(t, LoadFromJsonBytes(input, &c))
|
||||
assert.Equal(t, "hello", c.Name)
|
||||
assert.Equal(t, Int(3), c.Int)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithMapValueOfStruct(t *testing.T) {
|
||||
type (
|
||||
Value struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Config struct {
|
||||
Items map[string]Value
|
||||
}
|
||||
)
|
||||
|
||||
var inputs = [][]byte{
|
||||
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"Items": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||
[]byte(`{"items": {"key":{"name": "foo"}}}`),
|
||||
}
|
||||
for _, input := range inputs {
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, 1, len(c.Items))
|
||||
for _, v := range c.Items {
|
||||
assert.Equal(t, "foo", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesWithMapTypeValueOfStruct(t *testing.T) {
|
||||
type (
|
||||
Value struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
Map map[string]Value
|
||||
|
||||
Config struct {
|
||||
Map
|
||||
}
|
||||
)
|
||||
|
||||
var inputs = [][]byte{
|
||||
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"Map": {"Key":{"Name": "foo"}}}`),
|
||||
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||
[]byte(`{"map": {"key":{"name": "foo"}}}`),
|
||||
}
|
||||
for _, input := range inputs {
|
||||
var c Config
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, 1, len(c.Map))
|
||||
for _, v := range c.Map {
|
||||
assert.Equal(t, "foo", v.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FieldOverwrite(t *testing.T) {
|
||||
t.Run("normal", func(t *testing.T) {
|
||||
type Base struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*Base
|
||||
Name2 string
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*Base
|
||||
Name2 *string
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello", "Name2": "world"}`)
|
||||
assert.NoError(t, LoadFromJsonBytes(input, val))
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
|
||||
t.Run("Inherit Override", func(t *testing.T) {
|
||||
type Base struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
Base
|
||||
Name string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
Base
|
||||
Name int
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*Base
|
||||
Name int
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*Base
|
||||
Name *string
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
|
||||
t.Run("Inherit more", func(t *testing.T) {
|
||||
type Base1 struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type St0 struct {
|
||||
Base1
|
||||
Name string
|
||||
}
|
||||
|
||||
type St1 struct {
|
||||
St0
|
||||
Name string
|
||||
}
|
||||
|
||||
type St2 struct {
|
||||
St0
|
||||
Name int
|
||||
}
|
||||
|
||||
type St3 struct {
|
||||
*St0
|
||||
Name int
|
||||
}
|
||||
|
||||
type St4 struct {
|
||||
*St0
|
||||
Name *int
|
||||
}
|
||||
|
||||
validate := func(val interface{}) {
|
||||
input := []byte(`{"Name": "hello"}`)
|
||||
err := LoadFromJsonBytes(input, val)
|
||||
assert.ErrorAs(t, err, &dupErr)
|
||||
assert.Equal(t, newDupKeyError("name").Error(), err.Error())
|
||||
}
|
||||
|
||||
validate(&St0{})
|
||||
validate(&St1{})
|
||||
validate(&St2{})
|
||||
validate(&St3{})
|
||||
validate(&St4{})
|
||||
})
|
||||
}
|
||||
|
||||
func TestFieldOverwriteComplicated(t *testing.T) {
|
||||
t.Run("double maps", func(t *testing.T) {
|
||||
type (
|
||||
Base1 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
Base2 struct {
|
||||
Values map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Base1
|
||||
Base2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Values": {"Key": "Value"}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("merge children", func(t *testing.T) {
|
||||
type (
|
||||
Inner1 struct {
|
||||
Name string
|
||||
}
|
||||
Inner2 struct {
|
||||
Age int
|
||||
}
|
||||
Base1 struct {
|
||||
Inner Inner1
|
||||
}
|
||||
Base2 struct {
|
||||
Inner Inner2
|
||||
}
|
||||
Config struct {
|
||||
Base1
|
||||
Base2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Inner": {"Name": "foo", "Age": 10}}`)
|
||||
if assert.NoError(t, LoadFromJsonBytes(input, &c)) {
|
||||
assert.Equal(t, "foo", c.Base1.Inner.Name)
|
||||
assert.Equal(t, 10, c.Base2.Inner.Age)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("overwritten maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Inner": {"Map": {"Key": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten nested maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle1 struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
Middle2 struct {
|
||||
Map map[string]string
|
||||
Inner
|
||||
}
|
||||
Config struct {
|
||||
Middle1
|
||||
Middle2
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Middle1": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten outer/inner maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle struct {
|
||||
Inner
|
||||
Map map[string]string
|
||||
}
|
||||
Config struct {
|
||||
Middle
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Middle": {"Inner": {"Map": {"Key": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten anonymous maps", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Map map[string]string
|
||||
}
|
||||
Middle struct {
|
||||
Inner
|
||||
Map map[string]string
|
||||
}
|
||||
Elem map[string]Middle
|
||||
Config struct {
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Inner": {"Map": {"Key": "Value"}}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten primitive and map", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Value string
|
||||
}
|
||||
Elem map[string]Inner
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten map and slice", func(t *testing.T) {
|
||||
type (
|
||||
Inner struct {
|
||||
Value string
|
||||
}
|
||||
Elem []Inner
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten map and string", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Config struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Elem": {"Key": {"Value": "Value"}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadNamedFieldOverwritten(t *testing.T) {
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Val Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": {"Elem": {"Key": {"Value": "Value"}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named []struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals []Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named map[string]struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals map[string]Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": {"Key": {"Elem": {"Key": {"Value": "Value"}}}}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named *struct", func(t *testing.T) {
|
||||
type (
|
||||
Elem string
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem
|
||||
}
|
||||
Config struct {
|
||||
Vals *Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Vals": [{"Elem": {"Key": {"Value": "Value"}}}]}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type (
|
||||
Named struct {
|
||||
Elem string
|
||||
}
|
||||
Base struct {
|
||||
Named
|
||||
Elem Named
|
||||
}
|
||||
Config struct {
|
||||
Val Base
|
||||
}
|
||||
)
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": {"Elem": "Value"}}`)
|
||||
assert.ErrorAs(t, LoadFromJsonBytes(input, &c), &dupErr)
|
||||
})
|
||||
|
||||
t.Run("overwritten named struct", func(t *testing.T) {
|
||||
type Config struct {
|
||||
Val chan int
|
||||
}
|
||||
|
||||
var c Config
|
||||
input := []byte(`{"Val": 1}`)
|
||||
assert.Error(t, LoadFromJsonBytes(input, &c))
|
||||
})
|
||||
}
|
||||
|
||||
func createTempFile(ext, text string) (string, error) {
|
||||
tmpfile, err := os.CreateTemp(os.TempDir(), hash.Md5Hex([]byte(text))+"*"+ext)
|
||||
if err != nil {
|
||||
@@ -160,3 +1039,55 @@ func createTempFile(ext, text string) (string, error) {
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
func TestFillDefaultUnmarshal(t *testing.T) {
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := FillDefault(St{})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("not nil", func(t *testing.T) {
|
||||
type St struct{}
|
||||
err := FillDefault(&St{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("default", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st St
|
||||
err := FillDefault(&st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
})
|
||||
|
||||
t.Run("env", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
C string `json:",env=TEST_C"`
|
||||
}
|
||||
t.Setenv("TEST_C", "c")
|
||||
|
||||
var st St
|
||||
err := FillDefault(&st)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, st.A, "a")
|
||||
assert.Equal(t, st.C, "c")
|
||||
})
|
||||
|
||||
t.Run("has vaue", func(t *testing.T) {
|
||||
type St struct {
|
||||
A string `json:",default=a"`
|
||||
B string
|
||||
}
|
||||
var st = St{
|
||||
A: "b",
|
||||
}
|
||||
err := FillDefault(&st)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
// PropertyError represents a configuration error message.
|
||||
type PropertyError struct {
|
||||
error
|
||||
message string
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
```go
|
||||
type RestfulConf struct {
|
||||
ServiceName string `json:",env=SERVICE_NAME"` // read from env automatically
|
||||
Host string `json:",default=0.0.0.0"`
|
||||
Port int
|
||||
LogMode string `json:",options=[file,console]"`
|
||||
@@ -21,20 +22,20 @@ type RestfulConf struct {
|
||||
|
||||
```yaml
|
||||
# most fields are optional or have default values
|
||||
Port: 8080
|
||||
LogMode: console
|
||||
port: 8080
|
||||
logMode: console
|
||||
# you can use env settings
|
||||
MaxBytes: ${MAX_BYTES}
|
||||
maxBytes: ${MAX_BYTES}
|
||||
```
|
||||
|
||||
- toml example
|
||||
|
||||
```toml
|
||||
# most fields are optional or have default values
|
||||
Port = 8_080
|
||||
LogMode = "console"
|
||||
port = 8_080
|
||||
logMode = "console"
|
||||
# you can use env settings
|
||||
MaxBytes = "${MAX_BYTES}"
|
||||
maxBytes = "${MAX_BYTES}"
|
||||
```
|
||||
|
||||
3. Load the config from a file:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package discov
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/discov/internal"
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
@@ -51,12 +53,7 @@ func NewPublisher(endpoints []string, key, value string, opts ...PubOption) *Pub
|
||||
|
||||
// KeepAlive keeps key:value alive.
|
||||
func (p *Publisher) KeepAlive() error {
|
||||
cli, err := internal.GetRegistry().GetConn(p.endpoints)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.lease, err = p.register(cli)
|
||||
cli, err := p.doRegister()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -83,6 +80,43 @@ func (p *Publisher) Stop() {
|
||||
p.quit.Close()
|
||||
}
|
||||
|
||||
func (p *Publisher) doKeepAlive() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
select {
|
||||
case <-p.quit.Done():
|
||||
return nil
|
||||
default:
|
||||
cli, err := p.doRegister()
|
||||
if err != nil {
|
||||
logx.Errorf("etcd publisher doRegister: %s", err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
if err := p.keepAliveAsync(cli); err != nil {
|
||||
logx.Errorf("etcd publisher keepAliveAsync: %s", err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Publisher) doRegister() (internal.EtcdClient, error) {
|
||||
cli, err := internal.GetRegistry().GetConn(p.endpoints)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.lease, err = p.register(cli)
|
||||
return cli, err
|
||||
}
|
||||
|
||||
func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
ch, err := cli.KeepAlive(cli.Ctx(), p.lease)
|
||||
if err != nil {
|
||||
@@ -95,8 +129,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
p.revoke(cli)
|
||||
if err := p.KeepAlive(); err != nil {
|
||||
logx.Errorf("KeepAlive: %s", err.Error())
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -105,8 +139,8 @@ func (p *Publisher) keepAliveAsync(cli internal.EtcdClient) error {
|
||||
p.revoke(cli)
|
||||
select {
|
||||
case <-p.resumeChan:
|
||||
if err := p.KeepAlive(); err != nil {
|
||||
logx.Errorf("KeepAlive: %s", err.Error())
|
||||
if err := p.doKeepAlive(); err != nil {
|
||||
logx.Errorf("etcd publisher KeepAlive: %s", err.Error())
|
||||
}
|
||||
return
|
||||
case <-p.quit.Done():
|
||||
@@ -141,7 +175,7 @@ func (p *Publisher) register(client internal.EtcdClient) (clientv3.LeaseID, erro
|
||||
|
||||
func (p *Publisher) revoke(cli internal.EtcdClient) {
|
||||
if _, err := cli.Revoke(cli.Ctx(), p.lease); err != nil {
|
||||
logx.Error(err)
|
||||
logx.Errorf("etcd publisher revoke: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ type (
|
||||
// SubOption defines the method to customize a Subscriber.
|
||||
SubOption func(sub *Subscriber)
|
||||
|
||||
// A Subscriber is used to subscribe the given key on a etcd cluster.
|
||||
// A Subscriber is used to subscribe the given key on an etcd cluster.
|
||||
Subscriber struct {
|
||||
endpoints []string
|
||||
exclusive bool
|
||||
|
||||
@@ -53,10 +53,11 @@ func TestChunkExecutorFlushInterval(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestChunkExecutorEmpty(t *testing.T) {
|
||||
NewChunkExecutor(func(items []interface{}) {
|
||||
executor := NewChunkExecutor(func(items []interface{}) {
|
||||
assert.Fail(t, "should not called")
|
||||
}, WithChunkBytes(10), WithFlushInterval(time.Millisecond))
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
executor.Wait()
|
||||
}
|
||||
|
||||
func TestChunkExecutorFlush(t *testing.T) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
@@ -67,6 +68,7 @@ func TestPeriodicalExecutor_QuitGoroutine(t *testing.T) {
|
||||
ticker.Tick()
|
||||
ticker.Wait(time.Millisecond * idleRound)
|
||||
assert.Equal(t, routines, runtime.NumGoroutine())
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
func TestPeriodicalExecutor_Bulk(t *testing.T) {
|
||||
|
||||
15
core/fs/files_test.go
Normal file
15
core/fs/files_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package fs
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCloseOnExec(t *testing.T) {
|
||||
file := os.NewFile(0, os.DevNull)
|
||||
assert.NotPanics(t, func() {
|
||||
CloseOnExec(file)
|
||||
})
|
||||
}
|
||||
@@ -328,7 +328,7 @@ func (s Stream) Parallel(fn ParallelFunc, opts ...Option) {
|
||||
}, opts...).Done()
|
||||
}
|
||||
|
||||
// Reduce is a utility method to let the caller deal with the underlying channel.
|
||||
// Reduce is an utility method to let the caller deal with the underlying channel.
|
||||
func (s Stream) Reduce(fn ReduceFunc) (interface{}, error) {
|
||||
return fn(s.source)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/mapping"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -183,5 +182,5 @@ func innerRepr(node interface{}) string {
|
||||
}
|
||||
|
||||
func repr(node interface{}) string {
|
||||
return mapping.Repr(node)
|
||||
return lang.Repr(node)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ func (nopCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NopCloser returns a io.WriteCloser that does nothing on calling Close.
|
||||
// NopCloser returns an io.WriteCloser that does nothing on calling Close.
|
||||
func NopCloser(w io.Writer) io.WriteCloser {
|
||||
return nopCloser{w}
|
||||
}
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
package jsontype
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
)
|
||||
|
||||
// MilliTime represents time.Time that works better with mongodb.
|
||||
type MilliTime struct {
|
||||
time.Time
|
||||
}
|
||||
|
||||
// MarshalJSON marshals mt to json bytes.
|
||||
func (mt MilliTime) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(mt.Milli())
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals data into mt.
|
||||
func (mt *MilliTime) UnmarshalJSON(data []byte) error {
|
||||
var milli int64
|
||||
if err := json.Unmarshal(data, &milli); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mt.Time = time.Unix(0, milli*int64(time.Millisecond))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBSON returns BSON base on mt.
|
||||
func (mt MilliTime) GetBSON() (interface{}, error) {
|
||||
return mt.Time, nil
|
||||
}
|
||||
|
||||
// SetBSON sets raw into mt.
|
||||
func (mt *MilliTime) SetBSON(raw bson.Raw) error {
|
||||
return raw.Unmarshal(&mt.Time)
|
||||
}
|
||||
|
||||
// Milli returns milliseconds for mt.
|
||||
func (mt MilliTime) Milli() int64 {
|
||||
return mt.UnixNano() / int64(time.Millisecond)
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
package jsontype
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMilliTime_GetBSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got, err := MilliTime{test.tm}.GetBSON()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, test.tm, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMilliTime_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b, err := MilliTime{test.tm}.MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, strconv.FormatInt(test.tm.UnixNano()/1e6, 10), string(b))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMilliTime_Milli(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
n := MilliTime{test.tm}.Milli()
|
||||
assert.Equal(t, test.tm.UnixNano()/1e6, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMilliTime_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tm time.Time
|
||||
}{
|
||||
{
|
||||
name: "now",
|
||||
tm: time.Now(),
|
||||
},
|
||||
{
|
||||
name: "future",
|
||||
tm: time.Now().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var mt MilliTime
|
||||
s := strconv.FormatInt(test.tm.UnixNano()/1e6, 10)
|
||||
err := mt.UnmarshalJSON([]byte(s))
|
||||
assert.Nil(t, err)
|
||||
s1, err := mt.MarshalJSON()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, s, string(s1))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalWithError(t *testing.T) {
|
||||
var mt MilliTime
|
||||
assert.NotNil(t, mt.UnmarshalJSON([]byte("hello")))
|
||||
}
|
||||
|
||||
func TestSetBSON(t *testing.T) {
|
||||
data, err := bson.Marshal(time.Now())
|
||||
assert.Nil(t, err)
|
||||
|
||||
var raw bson.Raw
|
||||
assert.Nil(t, bson.Unmarshal(data, &raw))
|
||||
|
||||
var mt MilliTime
|
||||
assert.Nil(t, mt.SetBSON(raw))
|
||||
assert.NotNil(t, mt.SetBSON(bson.Raw{}))
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
package lang
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Placeholder is a placeholder object that can be used globally.
|
||||
var Placeholder PlaceholderType
|
||||
|
||||
@@ -9,3 +15,64 @@ type (
|
||||
// PlaceholderType represents a placeholder type.
|
||||
PlaceholderType = struct{}
|
||||
)
|
||||
|
||||
// Repr returns the string representation of v.
|
||||
func Repr(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// if func (v *Type) String() string, we can't use Elem()
|
||||
switch vt := v.(type) {
|
||||
case fmt.Stringer:
|
||||
return vt.String()
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(v)
|
||||
for val.Kind() == reflect.Ptr && !val.IsNil() {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
return reprOfValue(val)
|
||||
}
|
||||
|
||||
func reprOfValue(val reflect.Value) string {
|
||||
switch vt := val.Interface().(type) {
|
||||
case bool:
|
||||
return strconv.FormatBool(vt)
|
||||
case error:
|
||||
return vt.Error()
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(vt), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(vt, 'f', -1, 64)
|
||||
case fmt.Stringer:
|
||||
return vt.String()
|
||||
case int:
|
||||
return strconv.Itoa(vt)
|
||||
case int8:
|
||||
return strconv.Itoa(int(vt))
|
||||
case int16:
|
||||
return strconv.Itoa(int(vt))
|
||||
case int32:
|
||||
return strconv.Itoa(int(vt))
|
||||
case int64:
|
||||
return strconv.FormatInt(vt, 10)
|
||||
case string:
|
||||
return vt
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(vt, 10)
|
||||
case []byte:
|
||||
return string(vt)
|
||||
default:
|
||||
return fmt.Sprint(val.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
156
core/lang/lang_test.go
Normal file
156
core/lang/lang_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package lang
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRepr(t *testing.T) {
|
||||
var (
|
||||
f32 float32 = 1.1
|
||||
f64 = 2.2
|
||||
i8 int8 = 1
|
||||
i16 int16 = 2
|
||||
i32 int32 = 3
|
||||
i64 int64 = 4
|
||||
u8 uint8 = 5
|
||||
u16 uint16 = 6
|
||||
u32 uint32 = 7
|
||||
u64 uint64 = 8
|
||||
)
|
||||
tests := []struct {
|
||||
v interface{}
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
{
|
||||
mockStringable{},
|
||||
"mocked",
|
||||
},
|
||||
{
|
||||
new(mockStringable),
|
||||
"mocked",
|
||||
},
|
||||
{
|
||||
newMockPtr(),
|
||||
"mockptr",
|
||||
},
|
||||
{
|
||||
&mockOpacity{
|
||||
val: 1,
|
||||
},
|
||||
"{1}",
|
||||
},
|
||||
{
|
||||
true,
|
||||
"true",
|
||||
},
|
||||
{
|
||||
false,
|
||||
"false",
|
||||
},
|
||||
{
|
||||
f32,
|
||||
"1.1",
|
||||
},
|
||||
{
|
||||
f64,
|
||||
"2.2",
|
||||
},
|
||||
{
|
||||
i8,
|
||||
"1",
|
||||
},
|
||||
{
|
||||
i16,
|
||||
"2",
|
||||
},
|
||||
{
|
||||
i32,
|
||||
"3",
|
||||
},
|
||||
{
|
||||
i64,
|
||||
"4",
|
||||
},
|
||||
{
|
||||
u8,
|
||||
"5",
|
||||
},
|
||||
{
|
||||
u16,
|
||||
"6",
|
||||
},
|
||||
{
|
||||
u32,
|
||||
"7",
|
||||
},
|
||||
{
|
||||
u64,
|
||||
"8",
|
||||
},
|
||||
{
|
||||
[]byte(`abcd`),
|
||||
"abcd",
|
||||
},
|
||||
{
|
||||
mockOpacity{val: 1},
|
||||
"{1}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.expect, func(t *testing.T) {
|
||||
assert.Equal(t, test.expect, Repr(test.v))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReprOfValue(t *testing.T) {
|
||||
t.Run("error", func(t *testing.T) {
|
||||
assert.Equal(t, "error", reprOfValue(reflect.ValueOf(errors.New("error"))))
|
||||
})
|
||||
|
||||
t.Run("stringer", func(t *testing.T) {
|
||||
assert.Equal(t, "1.23", reprOfValue(reflect.ValueOf(json.Number("1.23"))))
|
||||
})
|
||||
|
||||
t.Run("int", func(t *testing.T) {
|
||||
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(1)))
|
||||
})
|
||||
|
||||
t.Run("int", func(t *testing.T) {
|
||||
assert.Equal(t, "1", reprOfValue(reflect.ValueOf("1")))
|
||||
})
|
||||
|
||||
t.Run("int", func(t *testing.T) {
|
||||
assert.Equal(t, "1", reprOfValue(reflect.ValueOf(uint(1))))
|
||||
})
|
||||
}
|
||||
|
||||
type mockStringable struct{}
|
||||
|
||||
func (m mockStringable) String() string {
|
||||
return "mocked"
|
||||
}
|
||||
|
||||
type mockPtr struct{}
|
||||
|
||||
func newMockPtr() *mockPtr {
|
||||
return new(mockPtr)
|
||||
}
|
||||
|
||||
func (m *mockPtr) String() string {
|
||||
return "mockptr"
|
||||
}
|
||||
|
||||
type mockOpacity struct {
|
||||
val int
|
||||
}
|
||||
@@ -27,6 +27,26 @@ func Close() error {
|
||||
return logx.Close()
|
||||
}
|
||||
|
||||
// Debug writes v into access log.
|
||||
func Debug(ctx context.Context, v ...interface{}) {
|
||||
getLogger(ctx).Debug(v...)
|
||||
}
|
||||
|
||||
// Debugf writes v with format into access log.
|
||||
func Debugf(ctx context.Context, format string, v ...interface{}) {
|
||||
getLogger(ctx).Debugf(format, v...)
|
||||
}
|
||||
|
||||
// Debugv writes v into access log with json content.
|
||||
func Debugv(ctx context.Context, v interface{}) {
|
||||
getLogger(ctx).Debugv(v)
|
||||
}
|
||||
|
||||
// Debugw writes msg along with fields into access log.
|
||||
func Debugw(ctx context.Context, msg string, fields ...LogField) {
|
||||
getLogger(ctx).Debugw(msg, fields...)
|
||||
}
|
||||
|
||||
// Error writes v into error log.
|
||||
func Error(ctx context.Context, v ...interface{}) {
|
||||
getLogger(ctx).Error(v...)
|
||||
|
||||
@@ -140,6 +140,54 @@ func TestInfow(t *testing.T) {
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebug(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debug(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugf(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugf(context.Background(), "foo %s", "bar")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugv(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugv(context.Background(), "foo")
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestDebugw(t *testing.T) {
|
||||
var buf strings.Builder
|
||||
writer := logx.NewWriter(&buf)
|
||||
old := logx.Reset()
|
||||
logx.SetWriter(writer)
|
||||
defer logx.SetWriter(old)
|
||||
|
||||
file, line := getFileLine()
|
||||
Debugw(context.Background(), "foo", Field("a", "b"))
|
||||
assert.True(t, strings.Contains(buf.String(), fmt.Sprintf("%s:%d", file, line+1)))
|
||||
}
|
||||
|
||||
func TestMust(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
Must(nil)
|
||||
|
||||
@@ -2,15 +2,34 @@ package logx
|
||||
|
||||
// A LogConf is a logging config.
|
||||
type LogConf struct {
|
||||
ServiceName string `json:",optional"`
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
TimeFormat string `json:",optional"`
|
||||
Path string `json:",default=logs"`
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
Compress bool `json:",optional"`
|
||||
KeepDays int `json:",optional"`
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// ServiceName represents the service name.
|
||||
ServiceName string `json:",optional"`
|
||||
// Mode represents the logging mode, default is `console`.
|
||||
// console: log to console.
|
||||
// file: log to file.
|
||||
// volume: used in k8s, prepend the hostname to the log file name.
|
||||
Mode string `json:",default=console,options=[console,file,volume]"`
|
||||
// Encoding represents the encoding type, default is `json`.
|
||||
// json: json encoding.
|
||||
// plain: plain text encoding, typically used in development.
|
||||
Encoding string `json:",default=json,options=[json,plain]"`
|
||||
// TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`.
|
||||
TimeFormat string `json:",optional"`
|
||||
// Path represents the log file path, default is `logs`.
|
||||
Path string `json:",default=logs"`
|
||||
// Level represents the log level, default is `info`.
|
||||
Level string `json:",default=info,options=[debug,info,error,severe]"`
|
||||
// MaxContentLength represents the max content bytes, default is no limit.
|
||||
MaxContentLength uint32 `json:",optional"`
|
||||
// Compress represents whether to compress the log file, default is `false`.
|
||||
Compress bool `json:",optional"`
|
||||
// Stdout represents whether to log statistics, default is `true`.
|
||||
Stat bool `json:",default=true"`
|
||||
// KeepDays represents how many days the log files will be kept. Default to keep all files.
|
||||
// Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`.
|
||||
KeepDays int `json:",optional"`
|
||||
// StackCooldownMillis represents the cooldown time for stack logging, default is 100ms.
|
||||
StackCooldownMillis int `json:",default=100"`
|
||||
// MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever.
|
||||
// Only take effect when RotationRuleType is `size`.
|
||||
// Even thougth `MaxBackups` sets 0, log files will still be removed
|
||||
|
||||
@@ -31,7 +31,10 @@ func AddGlobalFields(fields ...LogField) {
|
||||
func ContextWithFields(ctx context.Context, fields ...LogField) context.Context {
|
||||
if val := ctx.Value(fieldsContextKey); val != nil {
|
||||
if arr, ok := val.([]LogField); ok {
|
||||
return context.WithValue(ctx, fieldsContextKey, append(arr, fields...))
|
||||
allFields := make([]LogField, 0, len(arr)+len(fields))
|
||||
allFields = append(allFields, arr...)
|
||||
allFields = append(allFields, fields...)
|
||||
return context.WithValue(ctx, fieldsContextKey, allFields)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -67,6 +68,22 @@ func TestWithFieldsAppend(t *testing.T) {
|
||||
}, fields)
|
||||
}
|
||||
|
||||
func TestWithFieldsAppendCopy(t *testing.T) {
|
||||
const count = 10
|
||||
ctx := context.Background()
|
||||
for i := 0; i < count; i++ {
|
||||
ctx = ContextWithFields(ctx, Field(strconv.Itoa(i), 1))
|
||||
}
|
||||
|
||||
af := Field("foo", 1)
|
||||
bf := Field("bar", 2)
|
||||
ctxa := ContextWithFields(ctx, af)
|
||||
ctxb := ContextWithFields(ctx, bf)
|
||||
|
||||
assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count])
|
||||
assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count])
|
||||
}
|
||||
|
||||
func BenchmarkAtomicValue(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ var (
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07:00"
|
||||
logLevel uint32
|
||||
encoding uint32 = jsonEncodingType
|
||||
// maxContentLength is used to truncate the log content, 0 for not truncating.
|
||||
maxContentLength uint32
|
||||
// use uint32 for atomic operations
|
||||
disableLog uint32
|
||||
disableStat uint32
|
||||
@@ -230,10 +232,16 @@ func SetUp(c LogConf) (err error) {
|
||||
setupOnce.Do(func() {
|
||||
setupLogLevel(c)
|
||||
|
||||
if !c.Stat {
|
||||
DisableStat()
|
||||
}
|
||||
|
||||
if len(c.TimeFormat) > 0 {
|
||||
timeFormat = c.TimeFormat
|
||||
}
|
||||
|
||||
atomic.StoreUint32(&maxContentLength, c.MaxContentLength)
|
||||
|
||||
switch c.Encoding {
|
||||
case plainEncoding:
|
||||
atomic.StoreUint32(&encoding, plainEncodingType)
|
||||
|
||||
@@ -529,9 +529,9 @@ func TestSetLevel(t *testing.T) {
|
||||
|
||||
func TestSetLevelTwiceWithMode(t *testing.T) {
|
||||
testModes := []string{
|
||||
"mode",
|
||||
"console",
|
||||
"volumn",
|
||||
"mode",
|
||||
}
|
||||
w := new(mockWriter)
|
||||
old := writer.Swap(w)
|
||||
@@ -791,9 +791,12 @@ func doTestStructedLogConsole(t *testing.T, w *mockWriter, write func(...interfa
|
||||
func testSetLevelTwiceWithMode(t *testing.T, mode string, w *mockWriter) {
|
||||
writer.Store(nil)
|
||||
SetUp(LogConf{
|
||||
Mode: mode,
|
||||
Level: "error",
|
||||
Path: "/dev/null",
|
||||
Mode: mode,
|
||||
Level: "debug",
|
||||
Path: "/dev/null",
|
||||
Encoding: plainEncoding,
|
||||
Stat: false,
|
||||
TimeFormat: time.RFC3339,
|
||||
})
|
||||
SetUp(LogConf{
|
||||
Mode: mode,
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"github.com/zeromicro/go-zero/internal/trace"
|
||||
)
|
||||
|
||||
// WithCallerSkip returns a Logger with given caller skip.
|
||||
@@ -136,12 +136,12 @@ func (l *richLogger) buildFields(fields ...LogField) []LogField {
|
||||
return fields
|
||||
}
|
||||
|
||||
traceID := traceIdFromContext(l.ctx)
|
||||
traceID := trace.TraceIDFromContext(l.ctx)
|
||||
if len(traceID) > 0 {
|
||||
fields = append(fields, Field(traceKey, traceID))
|
||||
}
|
||||
|
||||
spanID := spanIdFromContext(l.ctx)
|
||||
spanID := trace.SpanIDFromContext(l.ctx)
|
||||
if len(spanID) > 0 {
|
||||
fields = append(fields, Field(spanKey, spanID))
|
||||
}
|
||||
@@ -179,21 +179,3 @@ func (l *richLogger) slow(v interface{}, fields ...LogField) {
|
||||
getWriter().Slow(v, l.buildFields(fields...)...)
|
||||
}
|
||||
}
|
||||
|
||||
func spanIdFromContext(ctx context.Context) string {
|
||||
spanCtx := trace.SpanContextFromContext(ctx)
|
||||
if spanCtx.HasSpanID() {
|
||||
return spanCtx.SpanID().String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func traceIdFromContext(ctx context.Context) string {
|
||||
spanCtx := trace.SpanContextFromContext(ctx)
|
||||
if spanCtx.HasTraceID() {
|
||||
return spanCtx.TraceID().String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -284,7 +284,7 @@ func (l *RotateLogger) getBackupFilename() string {
|
||||
func (l *RotateLogger) init() error {
|
||||
l.backup = l.rule.BackupFileName()
|
||||
|
||||
if _, err := os.Stat(l.filename); err != nil {
|
||||
if fileInfo, err := os.Stat(l.filename); err != nil {
|
||||
basePath := path.Dir(l.filename)
|
||||
if _, err = os.Stat(basePath); err != nil {
|
||||
if err = os.MkdirAll(basePath, defaultDirMode); err != nil {
|
||||
@@ -295,8 +295,11 @@ func (l *RotateLogger) init() error {
|
||||
if l.fp, err = os.Create(l.filename); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil {
|
||||
return err
|
||||
}
|
||||
l.currentSize = fileInfo.Size()
|
||||
}
|
||||
|
||||
fs.CloseOnExec(l.fp)
|
||||
|
||||
@@ -16,13 +16,13 @@ const (
|
||||
const (
|
||||
jsonEncodingType = iota
|
||||
plainEncodingType
|
||||
|
||||
plainEncoding = "plain"
|
||||
plainEncodingSep = '\t'
|
||||
sizeRotationRule = "size"
|
||||
)
|
||||
|
||||
const (
|
||||
plainEncoding = "plain"
|
||||
plainEncodingSep = '\t'
|
||||
sizeRotationRule = "size"
|
||||
|
||||
accessFilename = "access.log"
|
||||
errorFilename = "error.log"
|
||||
severeFilename = "severe.log"
|
||||
@@ -53,6 +53,7 @@ const (
|
||||
spanKey = "span"
|
||||
timestampKey = "@timestamp"
|
||||
traceKey = "trace"
|
||||
truncatedKey = "truncated"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -60,4 +61,6 @@ var (
|
||||
ErrLogPathNotSet = errors.New("log path must be set")
|
||||
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
||||
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
||||
|
||||
truncatedField = Field(truncatedKey, true)
|
||||
)
|
||||
|
||||
@@ -253,11 +253,11 @@ func (n nopWriter) Stack(_ interface{}) {
|
||||
func (n nopWriter) Stat(_ interface{}, _ ...LogField) {
|
||||
}
|
||||
|
||||
func buildFields(fields ...LogField) []string {
|
||||
func buildPlainFields(fields ...LogField) []string {
|
||||
var items []string
|
||||
|
||||
for _, field := range fields {
|
||||
items = append(items, fmt.Sprintf("%s=%v", field.Key, field.Value))
|
||||
items = append(items, fmt.Sprintf("%s=%+v", field.Key, field.Value))
|
||||
}
|
||||
|
||||
return items
|
||||
@@ -269,15 +269,29 @@ func combineGlobalFields(fields []LogField) []LogField {
|
||||
return fields
|
||||
}
|
||||
|
||||
return append(globals.([]LogField), fields...)
|
||||
gf := globals.([]LogField)
|
||||
ret := make([]LogField, 0, len(gf)+len(fields))
|
||||
ret = append(ret, gf...)
|
||||
ret = append(ret, fields...)
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func output(writer io.Writer, level string, val interface{}, fields ...LogField) {
|
||||
// only truncate string content, don't know how to truncate the values of other types.
|
||||
if v, ok := val.(string); ok {
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
if maxLen > 0 && len(v) > int(maxLen) {
|
||||
val = v[:maxLen]
|
||||
fields = append(fields, truncatedField)
|
||||
}
|
||||
}
|
||||
|
||||
fields = combineGlobalFields(fields)
|
||||
|
||||
switch atomic.LoadUint32(&encoding) {
|
||||
case plainEncodingType:
|
||||
writePlainAny(writer, level, val, buildFields(fields...)...)
|
||||
writePlainAny(writer, level, val, buildPlainFields(fields...)...)
|
||||
default:
|
||||
entry := make(logEntry)
|
||||
for _, field := range fields {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -107,7 +108,7 @@ func TestNopWriter(t *testing.T) {
|
||||
w.Stack("foo")
|
||||
w.Stat("foo")
|
||||
w.Slow("foo")
|
||||
w.Close()
|
||||
_ = w.Close()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -157,9 +158,40 @@ func TestWritePlainAny(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestLogWithLimitContentLength(t *testing.T) {
|
||||
maxLen := atomic.LoadUint32(&maxContentLength)
|
||||
atomic.StoreUint32(&maxContentLength, 10)
|
||||
|
||||
t.Cleanup(func() {
|
||||
atomic.StoreUint32(&maxContentLength, maxLen)
|
||||
})
|
||||
|
||||
t.Run("alert", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
w := NewWriter(&buf)
|
||||
w.Info("1234567890")
|
||||
var v1 mockedEntry
|
||||
if err := json.Unmarshal(buf.Bytes(), &v1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, "1234567890", v1.Content)
|
||||
assert.False(t, v1.Truncated)
|
||||
|
||||
buf.Reset()
|
||||
var v2 mockedEntry
|
||||
w.Info("12345678901")
|
||||
if err := json.Unmarshal(buf.Bytes(), &v2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, "1234567890", v2.Content)
|
||||
assert.True(t, v2.Truncated)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedEntry struct {
|
||||
Level string `json:"level"`
|
||||
Content string `json:"content"`
|
||||
Level string `json:"level"`
|
||||
Content string `json:"content"`
|
||||
Truncated bool `json:"truncated"`
|
||||
}
|
||||
|
||||
type easyToCloseWriter struct{}
|
||||
|
||||
@@ -8,10 +8,12 @@ type (
|
||||
// use context and OptionalDep option to determine the value of Optional
|
||||
// nothing to do with context.Context
|
||||
fieldOptionsWithContext struct {
|
||||
Inherit bool
|
||||
FromString bool
|
||||
Optional bool
|
||||
Options []string
|
||||
Default string
|
||||
EnvVar string
|
||||
Range *numberRange
|
||||
}
|
||||
|
||||
@@ -40,6 +42,10 @@ func (o *fieldOptionsWithContext) getDefault() (string, bool) {
|
||||
return o.Default, len(o.Default) > 0
|
||||
}
|
||||
|
||||
func (o *fieldOptionsWithContext) inherit() bool {
|
||||
return o != nil && o.Inherit
|
||||
}
|
||||
|
||||
func (o *fieldOptionsWithContext) optional() bool {
|
||||
return o != nil && o.Optional
|
||||
}
|
||||
@@ -101,5 +107,6 @@ func (o *fieldOptions) toOptionsWithContext(key string, m Valuer, fullName strin
|
||||
Optional: optional,
|
||||
Options: o.Options,
|
||||
Default: o.Default,
|
||||
EnvVar: o.EnvVar,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -11,22 +11,30 @@ const jsonTagKey = "json"
|
||||
var jsonUnmarshaler = NewUnmarshaler(jsonTagKey)
|
||||
|
||||
// UnmarshalJsonBytes unmarshals content into v.
|
||||
func UnmarshalJsonBytes(content []byte, v interface{}) error {
|
||||
return unmarshalJsonBytes(content, v, jsonUnmarshaler)
|
||||
func UnmarshalJsonBytes(content []byte, v interface{}, opts ...UnmarshalOption) error {
|
||||
return unmarshalJsonBytes(content, v, getJsonUnmarshaler(opts...))
|
||||
}
|
||||
|
||||
// UnmarshalJsonMap unmarshals content from m into v.
|
||||
func UnmarshalJsonMap(m map[string]interface{}, v interface{}) error {
|
||||
return jsonUnmarshaler.Unmarshal(m, v)
|
||||
func UnmarshalJsonMap(m map[string]interface{}, v interface{}, opts ...UnmarshalOption) error {
|
||||
return getJsonUnmarshaler(opts...).Unmarshal(m, v)
|
||||
}
|
||||
|
||||
// UnmarshalJsonReader unmarshals content from reader into v.
|
||||
func UnmarshalJsonReader(reader io.Reader, v interface{}) error {
|
||||
return unmarshalJsonReader(reader, v, jsonUnmarshaler)
|
||||
func UnmarshalJsonReader(reader io.Reader, v interface{}, opts ...UnmarshalOption) error {
|
||||
return unmarshalJsonReader(reader, v, getJsonUnmarshaler(opts...))
|
||||
}
|
||||
|
||||
func getJsonUnmarshaler(opts ...UnmarshalOption) *Unmarshaler {
|
||||
if len(opts) > 0 {
|
||||
return NewUnmarshaler(jsonTagKey, opts...)
|
||||
}
|
||||
|
||||
return jsonUnmarshaler
|
||||
}
|
||||
|
||||
func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error {
|
||||
var m map[string]interface{}
|
||||
var m interface{}
|
||||
if err := jsonx.Unmarshal(content, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -35,7 +43,7 @@ func unmarshalJsonBytes(content []byte, v interface{}, unmarshaler *Unmarshaler)
|
||||
}
|
||||
|
||||
func unmarshalJsonReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {
|
||||
var m map[string]interface{}
|
||||
var m interface{}
|
||||
if err := jsonx.UnmarshalFromReader(reader, &m); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -856,8 +856,7 @@ func TestUnmarshalBytesError(t *testing.T) {
|
||||
}
|
||||
|
||||
err := UnmarshalJsonBytes([]byte(payload), &v)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), payload))
|
||||
assert.Equal(t, errTypeMismatch, err)
|
||||
}
|
||||
|
||||
func TestUnmarshalReaderError(t *testing.T) {
|
||||
@@ -867,9 +866,7 @@ func TestUnmarshalReaderError(t *testing.T) {
|
||||
Any string
|
||||
}
|
||||
|
||||
err := UnmarshalJsonReader(reader, &v)
|
||||
assert.NotNil(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), payload))
|
||||
assert.Equal(t, errTypeMismatch, UnmarshalJsonReader(reader, &v))
|
||||
}
|
||||
|
||||
func TestUnmarshalMap(t *testing.T) {
|
||||
@@ -900,7 +897,9 @@ func TestUnmarshalMap(t *testing.T) {
|
||||
Any string `json:",optional"`
|
||||
}
|
||||
|
||||
err := UnmarshalJsonMap(m, &v)
|
||||
err := UnmarshalJsonMap(m, &v, WithCanonicalKeyFunc(func(s string) string {
|
||||
return s
|
||||
}))
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, len(v.Any) == 0)
|
||||
})
|
||||
@@ -918,3 +917,26 @@ func TestUnmarshalMap(t *testing.T) {
|
||||
assert.Equal(t, "foo", v.Any)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonArray(t *testing.T) {
|
||||
var v []struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
body := `[{"name":"kevin", "age": 18}]`
|
||||
assert.NoError(t, UnmarshalJsonBytes([]byte(body), &v))
|
||||
assert.Equal(t, 1, len(v))
|
||||
assert.Equal(t, "kevin", v[0].Name)
|
||||
assert.Equal(t, 18, v[0].Age)
|
||||
}
|
||||
|
||||
func TestUnmarshalJsonBytesError(t *testing.T) {
|
||||
var v []struct {
|
||||
Name string `json:"name"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
assert.Error(t, UnmarshalJsonBytes([]byte((``)), &v))
|
||||
assert.Error(t, UnmarshalJsonReader(strings.NewReader(``), &v))
|
||||
}
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
package mapping
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"github.com/zeromicro/go-zero/internal/encoding"
|
||||
)
|
||||
|
||||
// UnmarshalTomlBytes unmarshals TOML bytes into the given v.
|
||||
func UnmarshalTomlBytes(content []byte, v interface{}) error {
|
||||
return UnmarshalTomlReader(bytes.NewReader(content), v)
|
||||
func UnmarshalTomlBytes(content []byte, v interface{}, opts ...UnmarshalOption) error {
|
||||
b, err := encoding.TomlToJson(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalJsonBytes(b, v, opts...)
|
||||
}
|
||||
|
||||
// UnmarshalTomlReader unmarshals TOML from the given io.Reader into the given v.
|
||||
func UnmarshalTomlReader(r io.Reader, v interface{}) error {
|
||||
var val interface{}
|
||||
if err := toml.NewDecoder(r).Decode(&val); err != nil {
|
||||
func UnmarshalTomlReader(r io.Reader, v interface{}, opts ...UnmarshalOption) error {
|
||||
b, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := json.NewEncoder(&buf).Encode(val); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalJsonReader(&buf, v)
|
||||
return UnmarshalTomlBytes(b, v, opts...)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package mapping
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,7 +19,7 @@ d = "abcd!@#$112"
|
||||
C string `json:"c"`
|
||||
D string `json:"d"`
|
||||
}
|
||||
assert.Nil(t, UnmarshalTomlBytes([]byte(input), &val))
|
||||
assert.NoError(t, UnmarshalTomlReader(strings.NewReader(input), &val))
|
||||
assert.Equal(t, "foo", val.A)
|
||||
assert.Equal(t, 1, val.B)
|
||||
assert.Equal(t, "${FOO}", val.C)
|
||||
@@ -37,5 +38,12 @@ d = "abcd!@#$112"
|
||||
C string `json:"c"`
|
||||
D string `json:"d"`
|
||||
}
|
||||
assert.NotNil(t, UnmarshalTomlBytes([]byte(input), &val))
|
||||
assert.Error(t, UnmarshalTomlReader(strings.NewReader(input), &val))
|
||||
}
|
||||
|
||||
func TestUnmarshalTomlBadReader(t *testing.T) {
|
||||
var val struct {
|
||||
A string `json:"a"`
|
||||
}
|
||||
assert.Error(t, UnmarshalTomlReader(new(badReader), &val))
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -10,11 +10,14 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/lang"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultOption = "default"
|
||||
envOption = "env"
|
||||
inheritOption = "inherit"
|
||||
stringOption = "string"
|
||||
optionalOption = "optional"
|
||||
optionsOption = "options"
|
||||
@@ -53,7 +56,7 @@ type (
|
||||
|
||||
// Deref dereferences a type, if pointer type, returns its element type.
|
||||
func Deref(t reflect.Type) reflect.Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
@@ -62,22 +65,17 @@ func Deref(t reflect.Type) reflect.Type {
|
||||
|
||||
// Repr returns the string representation of v.
|
||||
func Repr(v interface{}) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
return lang.Repr(v)
|
||||
}
|
||||
|
||||
// if func (v *Type) String() string, we can't use Elem()
|
||||
switch vt := v.(type) {
|
||||
case fmt.Stringer:
|
||||
return vt.String()
|
||||
}
|
||||
// SetValue sets target to value, pointers are processed automatically.
|
||||
func SetValue(tp reflect.Type, value, target reflect.Value) {
|
||||
value.Set(convertTypeOfPtr(tp, target))
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(v)
|
||||
if val.Kind() == reflect.Ptr && !val.IsNil() {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
return reprOfValue(val)
|
||||
// SetMapIndexValue sets target to value at key position, pointers are processed automatically.
|
||||
func SetMapIndexValue(tp reflect.Type, value, key, target reflect.Value) {
|
||||
value.SetMapIndex(key, convertTypeOfPtr(tp, target))
|
||||
}
|
||||
|
||||
// ValidatePtr validates v if it's a valid pointer.
|
||||
@@ -91,10 +89,17 @@ func ValidatePtr(v *reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertType(kind reflect.Kind, str string) (interface{}, error) {
|
||||
func convertTypeFromString(kind reflect.Kind, str string) (interface{}, error) {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
return str == "1" || strings.ToLower(str) == "true", nil
|
||||
switch strings.ToLower(str) {
|
||||
case "1", "true":
|
||||
return true, nil
|
||||
case "0", "false":
|
||||
return false, nil
|
||||
default:
|
||||
return false, errTypeMismatch
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
intValue, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
@@ -123,6 +128,23 @@ func convertType(kind reflect.Kind, str string) (interface{}, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func convertTypeOfPtr(tp reflect.Type, target reflect.Value) reflect.Value {
|
||||
// keep the original value is a pointer
|
||||
if tp.Kind() == reflect.Ptr && target.CanAddr() {
|
||||
tp = tp.Elem()
|
||||
target = target.Addr()
|
||||
}
|
||||
|
||||
for tp.Kind() == reflect.Ptr {
|
||||
p := reflect.New(target.Type())
|
||||
p.Elem().Set(target)
|
||||
target = p
|
||||
tp = tp.Elem()
|
||||
}
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
func doParseKeyAndOptions(field reflect.StructField, value string) (string, *fieldOptions, error) {
|
||||
segments := parseSegments(value)
|
||||
key := strings.TrimSpace(segments[0])
|
||||
@@ -215,8 +237,8 @@ func isRightInclude(b byte) (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func maybeNewValue(field reflect.StructField, value reflect.Value) {
|
||||
if field.Type.Kind() == reflect.Ptr && value.IsNil() {
|
||||
func maybeNewValue(fieldType reflect.Type, value reflect.Value) {
|
||||
if fieldType.Kind() == reflect.Ptr && value.IsNil() {
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
}
|
||||
}
|
||||
@@ -335,6 +357,8 @@ func parseNumberRange(str string) (*numberRange, error) {
|
||||
|
||||
func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
|
||||
switch {
|
||||
case option == inheritOption:
|
||||
fieldOpts.Inherit = true
|
||||
case option == stringOption:
|
||||
fieldOpts.FromString = true
|
||||
case strings.HasPrefix(option, optionalOption):
|
||||
@@ -351,26 +375,33 @@ func parseOption(fieldOpts *fieldOptions, fieldName, option string) error {
|
||||
case option == optionalOption:
|
||||
fieldOpts.Optional = true
|
||||
case strings.HasPrefix(option, optionsOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return fmt.Errorf("field %s has wrong options", fieldName)
|
||||
val, err := parseProperty(fieldName, optionsOption, option)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldOpts.Options = parseOptions(segs[1])
|
||||
fieldOpts.Options = parseOptions(val)
|
||||
case strings.HasPrefix(option, defaultOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return fmt.Errorf("field %s has wrong default option", fieldName)
|
||||
val, err := parseProperty(fieldName, defaultOption, option)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldOpts.Default = strings.TrimSpace(segs[1])
|
||||
fieldOpts.Default = val
|
||||
case strings.HasPrefix(option, envOption):
|
||||
val, err := parseProperty(fieldName, envOption, option)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldOpts.EnvVar = val
|
||||
case strings.HasPrefix(option, rangeOption):
|
||||
segs := strings.Split(option, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return fmt.Errorf("field %s has wrong range", fieldName)
|
||||
val, err := parseProperty(fieldName, rangeOption, option)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nr, err := parseNumberRange(segs[1])
|
||||
nr, err := parseNumberRange(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -395,6 +426,15 @@ func parseOptions(val string) []string {
|
||||
return strings.Split(val, optionSeparator)
|
||||
}
|
||||
|
||||
func parseProperty(field, tag, val string) (string, error) {
|
||||
segs := strings.Split(val, equalToken)
|
||||
if len(segs) != 2 {
|
||||
return "", fmt.Errorf("field %s has wrong %s", field, tag)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(segs[1]), nil
|
||||
}
|
||||
|
||||
func parseSegments(val string) []string {
|
||||
var segments []string
|
||||
var escaped, grouped bool
|
||||
@@ -444,47 +484,6 @@ func parseSegments(val string) []string {
|
||||
return segments
|
||||
}
|
||||
|
||||
func reprOfValue(val reflect.Value) string {
|
||||
switch vt := val.Interface().(type) {
|
||||
case bool:
|
||||
return strconv.FormatBool(vt)
|
||||
case error:
|
||||
return vt.Error()
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(vt), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(vt, 'f', -1, 64)
|
||||
case fmt.Stringer:
|
||||
return vt.String()
|
||||
case int:
|
||||
return strconv.Itoa(vt)
|
||||
case int8:
|
||||
return strconv.Itoa(int(vt))
|
||||
case int16:
|
||||
return strconv.Itoa(int(vt))
|
||||
case int32:
|
||||
return strconv.Itoa(int(vt))
|
||||
case int64:
|
||||
return strconv.FormatInt(vt, 10)
|
||||
case string:
|
||||
return vt
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(vt), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(vt, 10)
|
||||
case []byte:
|
||||
return string(vt)
|
||||
default:
|
||||
return fmt.Sprint(val.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interface{}) error {
|
||||
switch kind {
|
||||
case reflect.Bool:
|
||||
@@ -504,13 +503,13 @@ func setMatchedPrimitiveValue(kind reflect.Kind, value reflect.Value, v interfac
|
||||
return nil
|
||||
}
|
||||
|
||||
func setValue(kind reflect.Kind, value reflect.Value, str string) error {
|
||||
func setValueFromString(kind reflect.Kind, value reflect.Value, str string) error {
|
||||
if !value.CanSet() {
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
value = ensureValue(value)
|
||||
v, err := convertType(kind, str)
|
||||
v, err := convertTypeFromString(kind, str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -583,7 +582,7 @@ func validateAndSetValue(kind reflect.Kind, value reflect.Value, str string, opt
|
||||
return errValueNotSettable
|
||||
}
|
||||
|
||||
v, err := convertType(kind, str)
|
||||
v, err := convertTypeFromString(kind, str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -237,7 +237,7 @@ func TestValidatePtrWithZeroValue(t *testing.T) {
|
||||
|
||||
func TestSetValueNotSettable(t *testing.T) {
|
||||
var i int
|
||||
assert.NotNil(t, setValue(reflect.Int, reflect.ValueOf(i), "1"))
|
||||
assert.NotNil(t, setValueFromString(reflect.Int, reflect.ValueOf(i), "1"))
|
||||
}
|
||||
|
||||
func TestParseKeyAndOptionsErrors(t *testing.T) {
|
||||
@@ -290,133 +290,9 @@ func TestSetValueFormatErrors(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.kind.String(), func(t *testing.T) {
|
||||
err := setValue(test.kind, test.target, test.value)
|
||||
err := setValueFromString(test.kind, test.target, test.value)
|
||||
assert.NotEqual(t, errValueNotSettable, err)
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepr(t *testing.T) {
|
||||
var (
|
||||
f32 float32 = 1.1
|
||||
f64 = 2.2
|
||||
i8 int8 = 1
|
||||
i16 int16 = 2
|
||||
i32 int32 = 3
|
||||
i64 int64 = 4
|
||||
u8 uint8 = 5
|
||||
u16 uint16 = 6
|
||||
u32 uint32 = 7
|
||||
u64 uint64 = 8
|
||||
)
|
||||
tests := []struct {
|
||||
v interface{}
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
{
|
||||
mockStringable{},
|
||||
"mocked",
|
||||
},
|
||||
{
|
||||
new(mockStringable),
|
||||
"mocked",
|
||||
},
|
||||
{
|
||||
newMockPtr(),
|
||||
"mockptr",
|
||||
},
|
||||
{
|
||||
&mockOpacity{
|
||||
val: 1,
|
||||
},
|
||||
"{1}",
|
||||
},
|
||||
{
|
||||
true,
|
||||
"true",
|
||||
},
|
||||
{
|
||||
false,
|
||||
"false",
|
||||
},
|
||||
{
|
||||
f32,
|
||||
"1.1",
|
||||
},
|
||||
{
|
||||
f64,
|
||||
"2.2",
|
||||
},
|
||||
{
|
||||
i8,
|
||||
"1",
|
||||
},
|
||||
{
|
||||
i16,
|
||||
"2",
|
||||
},
|
||||
{
|
||||
i32,
|
||||
"3",
|
||||
},
|
||||
{
|
||||
i64,
|
||||
"4",
|
||||
},
|
||||
{
|
||||
u8,
|
||||
"5",
|
||||
},
|
||||
{
|
||||
u16,
|
||||
"6",
|
||||
},
|
||||
{
|
||||
u32,
|
||||
"7",
|
||||
},
|
||||
{
|
||||
u64,
|
||||
"8",
|
||||
},
|
||||
{
|
||||
[]byte(`abcd`),
|
||||
"abcd",
|
||||
},
|
||||
{
|
||||
mockOpacity{val: 1},
|
||||
"{1}",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.expect, func(t *testing.T) {
|
||||
assert.Equal(t, test.expect, Repr(test.v))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockStringable struct{}
|
||||
|
||||
func (m mockStringable) String() string {
|
||||
return "mocked"
|
||||
}
|
||||
|
||||
type mockPtr struct{}
|
||||
|
||||
func newMockPtr() *mockPtr {
|
||||
return new(mockPtr)
|
||||
}
|
||||
|
||||
func (m *mockPtr) String() string {
|
||||
return "mockptr"
|
||||
}
|
||||
|
||||
type mockOpacity struct {
|
||||
val int
|
||||
}
|
||||
|
||||
@@ -7,12 +7,106 @@ type (
|
||||
Value(key string) (interface{}, bool)
|
||||
}
|
||||
|
||||
// A MapValuer is a map that can use Value method to get values with given keys.
|
||||
MapValuer map[string]interface{}
|
||||
// A valuerWithParent defines a node that has a parent node.
|
||||
valuerWithParent interface {
|
||||
Valuer
|
||||
// Parent get the parent valuer for current node.
|
||||
Parent() valuerWithParent
|
||||
}
|
||||
|
||||
// A node is a map that can use Value method to get values with given keys.
|
||||
node struct {
|
||||
current Valuer
|
||||
parent valuerWithParent
|
||||
}
|
||||
|
||||
// A valueWithParent is used to wrap the value with its parent.
|
||||
valueWithParent struct {
|
||||
value interface{}
|
||||
parent valuerWithParent
|
||||
}
|
||||
|
||||
// mapValuer is a type for map to meet the Valuer interface.
|
||||
mapValuer map[string]interface{}
|
||||
// simpleValuer is a type to get value from current node.
|
||||
simpleValuer node
|
||||
// recursiveValuer is a type to get the value recursively from current and parent nodes.
|
||||
recursiveValuer node
|
||||
)
|
||||
|
||||
// Value gets the value associated with the given key from mv.
|
||||
func (mv MapValuer) Value(key string) (interface{}, bool) {
|
||||
// Value gets the value assciated with the given key from mv.
|
||||
func (mv mapValuer) Value(key string) (interface{}, bool) {
|
||||
v, ok := mv[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Value gets the value associated with the given key from sv.
|
||||
func (sv simpleValuer) Value(key string) (interface{}, bool) {
|
||||
v, ok := sv.current.Value(key)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Parent get the parent valuer from sv.
|
||||
func (sv simpleValuer) Parent() valuerWithParent {
|
||||
if sv.parent == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return recursiveValuer{
|
||||
current: sv.parent,
|
||||
parent: sv.parent.Parent(),
|
||||
}
|
||||
}
|
||||
|
||||
// Value gets the value associated with the given key from rv,
|
||||
// and it will inherit the value from parent nodes.
|
||||
func (rv recursiveValuer) Value(key string) (interface{}, bool) {
|
||||
val, ok := rv.current.Value(key)
|
||||
if !ok {
|
||||
if parent := rv.Parent(); parent != nil {
|
||||
return parent.Value(key)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
vm, ok := val.(map[string]interface{})
|
||||
if !ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
parent := rv.Parent()
|
||||
if parent == nil {
|
||||
return val, true
|
||||
}
|
||||
|
||||
pv, ok := parent.Value(key)
|
||||
if !ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
pm, ok := pv.(map[string]interface{})
|
||||
if !ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
for k, v := range pm {
|
||||
if _, ok := vm[k]; !ok {
|
||||
vm[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return vm, true
|
||||
}
|
||||
|
||||
// Parent get the parent valuer from rv.
|
||||
func (rv recursiveValuer) Parent() valuerWithParent {
|
||||
if rv.parent == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return recursiveValuer{
|
||||
current: rv.parent,
|
||||
parent: rv.parent.Parent(),
|
||||
}
|
||||
}
|
||||
|
||||
57
core/mapping/valuer_test.go
Normal file
57
core/mapping/valuer_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package mapping
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMapValuerWithInherit_Value(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"discovery": map[string]interface{}{
|
||||
"host": "localhost",
|
||||
"port": 8080,
|
||||
},
|
||||
"component": map[string]interface{}{
|
||||
"name": "test",
|
||||
},
|
||||
}
|
||||
valuer := recursiveValuer{
|
||||
current: mapValuer(input["component"].(map[string]interface{})),
|
||||
parent: simpleValuer{
|
||||
current: mapValuer(input),
|
||||
},
|
||||
}
|
||||
|
||||
val, ok := valuer.Value("discovery")
|
||||
assert.True(t, ok)
|
||||
|
||||
m, ok := val.(map[string]interface{})
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "localhost", m["host"])
|
||||
assert.Equal(t, 8080, m["port"])
|
||||
}
|
||||
|
||||
func TestRecursiveValuer_Value(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"component": map[string]interface{}{
|
||||
"name": "test",
|
||||
"foo": map[string]interface{}{
|
||||
"bar": "baz",
|
||||
},
|
||||
},
|
||||
"foo": "value",
|
||||
}
|
||||
valuer := recursiveValuer{
|
||||
current: mapValuer(input["component"].(map[string]interface{})),
|
||||
parent: simpleValuer{
|
||||
current: mapValuer(input),
|
||||
},
|
||||
}
|
||||
|
||||
val, ok := valuer.Value("foo")
|
||||
assert.True(t, ok)
|
||||
assert.EqualValues(t, map[string]interface{}{
|
||||
"bar": "baz",
|
||||
}, val)
|
||||
}
|
||||
@@ -1,101 +1,27 @@
|
||||
package mapping
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// To make .json & .yaml consistent, we just use json as the tag key.
|
||||
const yamlTagKey = "json"
|
||||
|
||||
var (
|
||||
// ErrUnsupportedType is an error that indicates the config format is not supported.
|
||||
ErrUnsupportedType = errors.New("only map-like configs are supported")
|
||||
|
||||
yamlUnmarshaler = NewUnmarshaler(yamlTagKey)
|
||||
"github.com/zeromicro/go-zero/internal/encoding"
|
||||
)
|
||||
|
||||
// UnmarshalYamlBytes unmarshals content into v.
|
||||
func UnmarshalYamlBytes(content []byte, v interface{}) error {
|
||||
return unmarshalYamlBytes(content, v, yamlUnmarshaler)
|
||||
func UnmarshalYamlBytes(content []byte, v interface{}, opts ...UnmarshalOption) error {
|
||||
b, err := encoding.YamlToJson(content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return UnmarshalJsonBytes(b, v, opts...)
|
||||
}
|
||||
|
||||
// UnmarshalYamlReader unmarshals content from reader into v.
|
||||
func UnmarshalYamlReader(reader io.Reader, v interface{}) error {
|
||||
return unmarshalYamlReader(reader, v, yamlUnmarshaler)
|
||||
}
|
||||
|
||||
func cleanupInterfaceMap(in map[interface{}]interface{}) map[string]interface{} {
|
||||
res := make(map[string]interface{})
|
||||
for k, v := range in {
|
||||
res[Repr(k)] = cleanupMapValue(v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func cleanupInterfaceNumber(in interface{}) json.Number {
|
||||
return json.Number(Repr(in))
|
||||
}
|
||||
|
||||
func cleanupInterfaceSlice(in []interface{}) []interface{} {
|
||||
res := make([]interface{}, len(in))
|
||||
for i, v := range in {
|
||||
res[i] = cleanupMapValue(v)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func cleanupMapValue(v interface{}) interface{} {
|
||||
switch v := v.(type) {
|
||||
case []interface{}:
|
||||
return cleanupInterfaceSlice(v)
|
||||
case map[interface{}]interface{}:
|
||||
return cleanupInterfaceMap(v)
|
||||
case bool, string:
|
||||
return v
|
||||
case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64, float32, float64:
|
||||
return cleanupInterfaceNumber(v)
|
||||
default:
|
||||
return Repr(v)
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshal(unmarshaler *Unmarshaler, o, v interface{}) error {
|
||||
if m, ok := o.(map[string]interface{}); ok {
|
||||
return unmarshaler.Unmarshal(m, v)
|
||||
}
|
||||
|
||||
return ErrUnsupportedType
|
||||
}
|
||||
|
||||
func unmarshalYamlBytes(content []byte, v interface{}, unmarshaler *Unmarshaler) error {
|
||||
var o interface{}
|
||||
if err := yamlUnmarshal(content, &o); err != nil {
|
||||
func UnmarshalYamlReader(reader io.Reader, v interface{}, opts ...UnmarshalOption) error {
|
||||
b, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unmarshal(unmarshaler, o, v)
|
||||
}
|
||||
|
||||
func unmarshalYamlReader(reader io.Reader, v interface{}, unmarshaler *Unmarshaler) error {
|
||||
var res interface{}
|
||||
if err := yaml.NewDecoder(reader).Decode(&res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return unmarshal(unmarshaler, cleanupMapValue(res), v)
|
||||
}
|
||||
|
||||
// yamlUnmarshal YAML to map[string]interface{} instead of map[interface{}]interface{}.
|
||||
func yamlUnmarshal(in []byte, out interface{}) error {
|
||||
var res interface{}
|
||||
if err := yaml.Unmarshal(in, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*out.(*interface{}) = cleanupMapValue(res)
|
||||
return nil
|
||||
return UnmarshalYamlBytes(b, v, opts...)
|
||||
}
|
||||
|
||||
@@ -934,9 +934,8 @@ func TestUnmarshalYamlReaderError(t *testing.T) {
|
||||
err := UnmarshalYamlReader(reader, &v)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
reader = strings.NewReader("chenquan")
|
||||
err = UnmarshalYamlReader(reader, &v)
|
||||
assert.ErrorIs(t, err, ErrUnsupportedType)
|
||||
reader = strings.NewReader("foo")
|
||||
assert.Error(t, UnmarshalYamlReader(reader, &v))
|
||||
}
|
||||
|
||||
func TestUnmarshalYamlBadReader(t *testing.T) {
|
||||
@@ -1012,6 +1011,13 @@ func TestUnmarshalYamlMapRune(t *testing.T) {
|
||||
assert.Equal(t, rune(3), v.Machine["node3"])
|
||||
}
|
||||
|
||||
func TestUnmarshalYamlBadInput(t *testing.T) {
|
||||
var v struct {
|
||||
Any string
|
||||
}
|
||||
assert.Error(t, UnmarshalYamlBytes([]byte("':foo"), &v))
|
||||
}
|
||||
|
||||
type badReader struct{}
|
||||
|
||||
func (b *badReader) Read(_ []byte) (n int, err error) {
|
||||
|
||||
@@ -6,14 +6,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Unstable is used to generate random value around the mean value base on given deviation.
|
||||
// An Unstable is used to generate random value around the mean value base on given deviation.
|
||||
type Unstable struct {
|
||||
deviation float64
|
||||
r *rand.Rand
|
||||
lock *sync.Mutex
|
||||
}
|
||||
|
||||
// NewUnstable returns a Unstable.
|
||||
// NewUnstable returns an Unstable.
|
||||
func NewUnstable(deviation float64) Unstable {
|
||||
if deviation < 0 {
|
||||
deviation = 0
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
"github.com/zeromicro/go-zero/core/prometheus"
|
||||
)
|
||||
|
||||
@@ -17,6 +18,9 @@ func TestNewCounterVec(t *testing.T) {
|
||||
})
|
||||
defer counterVec.close()
|
||||
counterVecNil := NewCounterVec(nil)
|
||||
counterVec.Inc("path", "code")
|
||||
counterVec.Add(1, "path", "code")
|
||||
proc.Shutdown()
|
||||
assert.NotNil(t, counterVec)
|
||||
assert.Nil(t, counterVecNil)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNewGaugeVec(t *testing.T) {
|
||||
@@ -18,6 +19,8 @@ func TestNewGaugeVec(t *testing.T) {
|
||||
gaugeVecNil := NewGaugeVec(nil)
|
||||
assert.NotNil(t, gaugeVec)
|
||||
assert.Nil(t, gaugeVecNil)
|
||||
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
func TestGaugeInc(t *testing.T) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNewHistogramVec(t *testing.T) {
|
||||
@@ -47,4 +48,6 @@ func TestHistogramObserve(t *testing.T) {
|
||||
|
||||
err := testutil.CollectAndCompare(hv.histogram, strings.NewReader(metadata+val))
|
||||
assert.Nil(t, err)
|
||||
|
||||
proc.Shutdown()
|
||||
}
|
||||
|
||||
@@ -15,5 +15,14 @@ func AddWrapUpListener(fn func()) func() {
|
||||
return fn
|
||||
}
|
||||
|
||||
// SetTimeToForceQuit does nothing on windows.
|
||||
func SetTimeToForceQuit(duration time.Duration) {
|
||||
}
|
||||
|
||||
// Shutdown does nothing on windows.
|
||||
func Shutdown() {
|
||||
}
|
||||
|
||||
// WrapUp does nothing on windows.
|
||||
func WrapUp() {
|
||||
}
|
||||
|
||||
@@ -43,6 +43,16 @@ func SetTimeToForceQuit(duration time.Duration) {
|
||||
delayTimeBeforeForceQuit = duration
|
||||
}
|
||||
|
||||
// Shutdown calls the registered shutdown listeners, only for test purpose.
|
||||
func Shutdown() {
|
||||
shutdownListeners.notifyListeners()
|
||||
}
|
||||
|
||||
// WrapUp wraps up the process, only for test purpose.
|
||||
func WrapUp() {
|
||||
wrapUpListeners.notifyListeners()
|
||||
}
|
||||
|
||||
func gracefulStop(signals chan os.Signal) {
|
||||
signal.Stop(signals)
|
||||
|
||||
|
||||
@@ -18,14 +18,14 @@ func TestShutdown(t *testing.T) {
|
||||
called := AddWrapUpListener(func() {
|
||||
val++
|
||||
})
|
||||
wrapUpListeners.notifyListeners()
|
||||
WrapUp()
|
||||
called()
|
||||
assert.Equal(t, 1, val)
|
||||
|
||||
called = AddShutdownListener(func() {
|
||||
val += 2
|
||||
})
|
||||
shutdownListeners.notifyListeners()
|
||||
Shutdown()
|
||||
called()
|
||||
assert.Equal(t, 3, val)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,8 @@ import "github.com/zeromicro/go-zero/core/logx"
|
||||
|
||||
// Recover is used with defer to do cleanup on panics.
|
||||
// Use it like:
|
||||
// defer Recover(func() {})
|
||||
//
|
||||
// defer Recover(func() {})
|
||||
func Recover(cleanups ...func()) {
|
||||
for _, cleanup := range cleanups {
|
||||
cleanup()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/zeromicro/go-zero/core/prometheus"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"github.com/zeromicro/go-zero/internal/devserver"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,10 +29,12 @@ const (
|
||||
type ServiceConf struct {
|
||||
Name string
|
||||
Log logx.LogConf
|
||||
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"`
|
||||
MetricsUrl string `json:",optional"`
|
||||
Mode string `json:",default=pro,options=dev|test|rt|pre|pro"`
|
||||
MetricsUrl string `json:",optional"`
|
||||
// Deprecated: please use DevServer
|
||||
Prometheus prometheus.Config `json:",optional"`
|
||||
Telemetry trace.Config `json:",optional"`
|
||||
DevServer devserver.Config `json:",optional"`
|
||||
}
|
||||
|
||||
// MustSetUp sets up the service, exits on error.
|
||||
@@ -64,6 +67,7 @@ func (sc ServiceConf) SetUp() error {
|
||||
if len(sc.MetricsUrl) > 0 {
|
||||
stat.SetReportWriter(stat.NewRemoteWriter(sc.MetricsUrl))
|
||||
}
|
||||
devserver.StartAgent(sc.DevServer)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
@@ -16,3 +17,15 @@ func TestServiceConf(t *testing.T) {
|
||||
}
|
||||
c.MustSetUp()
|
||||
}
|
||||
|
||||
func TestServiceConfWithMetricsUrl(t *testing.T) {
|
||||
c := ServiceConf{
|
||||
Name: "foo",
|
||||
Log: logx.LogConf{
|
||||
Mode: "volume",
|
||||
},
|
||||
Mode: "dev",
|
||||
MetricsUrl: "http://localhost:8080",
|
||||
}
|
||||
assert.NoError(t, c.SetUp())
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,6 +56,7 @@ func TestServiceGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
group.Stop()
|
||||
proc.Shutdown()
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package stat
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -12,6 +13,9 @@ import (
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
os.Setenv(clusterNameKey, "test-cluster")
|
||||
defer os.Unsetenv(clusterNameKey)
|
||||
|
||||
var count int32
|
||||
SetReporter(func(s string) {
|
||||
atomic.AddInt32(&count, 1)
|
||||
|
||||
@@ -258,7 +258,7 @@ func parseUints(val string) ([]uint64, error) {
|
||||
return sets, nil
|
||||
}
|
||||
|
||||
// runningInUserNS detects whether we are currently running in a user namespace.
|
||||
// runningInUserNS detects whether we are currently running in an user namespace.
|
||||
func runningInUserNS() bool {
|
||||
nsOnce.Do(func() {
|
||||
file, err := os.Open("/proc/self/uid_map")
|
||||
|
||||
@@ -33,13 +33,7 @@ func initialize() {
|
||||
}
|
||||
|
||||
cores = uint64(len(cpus))
|
||||
sets, err := cpuSets()
|
||||
if err != nil {
|
||||
logx.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
quota = float64(len(sets))
|
||||
quota = float64(len(cpus))
|
||||
cq, err := cpuQuota()
|
||||
if err == nil {
|
||||
if cq != -1 {
|
||||
|
||||
@@ -45,9 +45,13 @@ func RawFieldNames(in interface{}, postgresSql ...bool) []string {
|
||||
// `db:"id"`
|
||||
// `db:"id,type=char,length=16"`
|
||||
// `db:",type=char,length=16"`
|
||||
// `db:"-,type=char,length=16"`
|
||||
if strings.Contains(tagv, ",") {
|
||||
tagv = strings.TrimSpace(strings.Split(tagv, ",")[0])
|
||||
}
|
||||
if tagv == "-" {
|
||||
continue
|
||||
}
|
||||
if len(tagv) == 0 {
|
||||
tagv = fi.Name
|
||||
}
|
||||
|
||||
@@ -39,3 +39,33 @@ func TestFieldNamesWithTagOptions(t *testing.T) {
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedUserWithDashTag struct {
|
||||
ID string `db:"id" json:"id,omitempty"`
|
||||
UserName string `db:"user_name" json:"userName,omitempty"`
|
||||
Mobile string `db:"-" json:"mobile,omitempty"`
|
||||
}
|
||||
|
||||
func TestFieldNamesWithDashTag(t *testing.T) {
|
||||
t.Run("new", func(t *testing.T) {
|
||||
var u mockedUserWithDashTag
|
||||
out := RawFieldNames(&u)
|
||||
expected := []string{"`id`", "`user_name`"}
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
type mockedUserWithDashTagAndOptions struct {
|
||||
ID string `db:"id" json:"id,omitempty"`
|
||||
UserName string `db:"user_name,type=varchar,length=255" json:"userName,omitempty"`
|
||||
Mobile string `db:"-,type=varchar,length=255" json:"mobile,omitempty"`
|
||||
}
|
||||
|
||||
func TestFieldNamesWithDashTagAndOptions(t *testing.T) {
|
||||
t.Run("new", func(t *testing.T) {
|
||||
var u mockedUserWithDashTagAndOptions
|
||||
out := RawFieldNames(&u)
|
||||
expected := []string{"`id`", "`user_name`"}
|
||||
assert.Equal(t, expected, out)
|
||||
})
|
||||
}
|
||||
|
||||
5
core/stores/cache/cache.go
vendored
5
core/stores/cache/cache.go
vendored
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
@@ -62,12 +63,12 @@ func New(c ClusterConf, barrier syncx.SingleFlight, st *Stat, errNotFound error,
|
||||
}
|
||||
|
||||
if len(c) == 1 {
|
||||
return NewNode(c[0].NewRedis(), barrier, st, errNotFound, opts...)
|
||||
return NewNode(redis.MustNewRedis(c[0].RedisConf), barrier, st, errNotFound, opts...)
|
||||
}
|
||||
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := NewNode(node.NewRedis(), barrier, st, errNotFound, opts...)
|
||||
cn := NewNode(redis.MustNewRedis(node.RedisConf), barrier, st, errNotFound, opts...)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
|
||||
119
core/stores/cache/cache_test.go
vendored
119
core/stores/cache/cache_test.go
vendored
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/errorx"
|
||||
"github.com/zeromicro/go-zero/core/hash"
|
||||
@@ -109,51 +110,85 @@ func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val interface{}, ke
|
||||
}
|
||||
|
||||
func TestCache_SetDel(t *testing.T) {
|
||||
const total = 1000
|
||||
r1, clean1, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean1()
|
||||
r2, clean2, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean2()
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr,
|
||||
Type: redis.NodeType,
|
||||
t.Run("test set del", func(t *testing.T) {
|
||||
const total = 1000
|
||||
r1, clean1, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean1()
|
||||
r2, clean2, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
defer clean2()
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr,
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr,
|
||||
Type: redis.NodeType,
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr,
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
|
||||
} else {
|
||||
assert.Nil(t, c.SetWithExpire(fmt.Sprintf("key/%d", i), i, 0))
|
||||
}
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
|
||||
assert.Equal(t, i, val)
|
||||
}
|
||||
assert.Nil(t, c.Del())
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
for i := 0; i < total; i++ {
|
||||
if i%2 == 0 {
|
||||
assert.Nil(t, c.Set(fmt.Sprintf("key/%d", i), i))
|
||||
} else {
|
||||
assert.Nil(t, c.SetWithExpire(fmt.Sprintf("key/%d", i), i, 0))
|
||||
}
|
||||
}
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.Nil(t, c.Get(fmt.Sprintf("key/%d", i), &val))
|
||||
assert.Equal(t, i, val)
|
||||
}
|
||||
assert.Nil(t, c.Del())
|
||||
for i := 0; i < total; i++ {
|
||||
assert.Nil(t, c.Del(fmt.Sprintf("key/%d", i)))
|
||||
}
|
||||
assert.Nil(t, c.Del("a", "b", "c"))
|
||||
for i := 0; i < total; i++ {
|
||||
var val int
|
||||
assert.True(t, c.IsNotFound(c.Get(fmt.Sprintf("key/%d", i), &val)))
|
||||
assert.Equal(t, 0, val)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test set del error", func(t *testing.T) {
|
||||
r1, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r1.Close()
|
||||
|
||||
r2, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r2.Close()
|
||||
|
||||
conf := ClusterConf{
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r1.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
{
|
||||
RedisConf: redis.RedisConf{
|
||||
Host: r2.Addr(),
|
||||
Type: redis.NodeType,
|
||||
},
|
||||
Weight: 100,
|
||||
},
|
||||
}
|
||||
c := New(conf, syncx.NewSingleFlight(), NewStat("mock"), errPlaceholder)
|
||||
r1.SetError("mock error")
|
||||
r2.SetError("mock error")
|
||||
assert.NoError(t, c.Del("a", "b", "c"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestCache_OneNode(t *testing.T) {
|
||||
|
||||
3
core/stores/cache/cachenode.go
vendored
3
core/stores/cache/cachenode.go
vendored
@@ -277,5 +277,6 @@ func (c cacheNode) processCache(ctx context.Context, key, data string, v interfa
|
||||
|
||||
func (c cacheNode) setCacheWithNotFound(ctx context.Context, key string) error {
|
||||
seconds := int(math.Ceil(c.aroundDuration(c.notFoundExpiry).Seconds()))
|
||||
return c.rds.SetexCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||
_, err := c.rds.SetnxExCtx(ctx, key, notFoundPlaceholder, seconds)
|
||||
return err
|
||||
}
|
||||
|
||||
114
core/stores/cache/cachenode_test.go
vendored
114
core/stores/cache/cachenode_test.go
vendored
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -11,12 +12,14 @@ import (
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/collection"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/mathx"
|
||||
"github.com/zeromicro/go-zero/core/stat"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis/redistest"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
var errTestNotFound = errors.New("not found")
|
||||
@@ -27,27 +30,54 @@ func init() {
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCache(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
defer clean()
|
||||
t.Run("del cache", func(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
store.Type = redis.ClusterType
|
||||
defer clean()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
assert.Nil(t, cn.Del())
|
||||
assert.Nil(t, cn.Del([]string{}...))
|
||||
assert.Nil(t, cn.Del(make([]string, 0)...))
|
||||
cn.Set("first", "one")
|
||||
assert.Nil(t, cn.Del("first"))
|
||||
cn.Set("first", "one")
|
||||
cn.Set("second", "two")
|
||||
assert.Nil(t, cn.Del("first", "second"))
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
assert.Nil(t, cn.Del())
|
||||
assert.Nil(t, cn.Del([]string{}...))
|
||||
assert.Nil(t, cn.Del(make([]string, 0)...))
|
||||
cn.Set("first", "one")
|
||||
assert.Nil(t, cn.Del("first"))
|
||||
cn.Set("first", "one")
|
||||
cn.Set("second", "two")
|
||||
assert.Nil(t, cn.Del("first", "second"))
|
||||
})
|
||||
|
||||
t.Run("del cache with errors", func(t *testing.T) {
|
||||
old := timingWheel
|
||||
ticker := timex.NewFakeTicker()
|
||||
var err error
|
||||
timingWheel, err = collection.NewTimingWheelWithTicker(
|
||||
time.Millisecond, timingWheelSlots, func(key, value interface{}) {
|
||||
clean(key, value)
|
||||
}, ticker)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
timingWheel = old
|
||||
})
|
||||
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
r.SetError("mock error")
|
||||
|
||||
node := NewNode(redis.New(r.Addr(), redis.Cluster()), syncx.NewSingleFlight(),
|
||||
NewStat("any"), errTestNotFound)
|
||||
assert.NoError(t, node.Del("foo", "bar"))
|
||||
ticker.Tick()
|
||||
runtime.Gosched()
|
||||
})
|
||||
}
|
||||
|
||||
func TestCacheNode_DelCacheWithErrors(t *testing.T) {
|
||||
@@ -125,6 +155,21 @@ func TestCacheNode_Take(t *testing.T) {
|
||||
assert.Equal(t, `"value"`, val)
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeBadRedis(t *testing.T) {
|
||||
r, err := miniredis.Run()
|
||||
assert.NoError(t, err)
|
||||
defer r.Close()
|
||||
r.SetError("mock error")
|
||||
|
||||
cn := NewNode(redis.New(r.Addr()), syncx.NewSingleFlight(), NewStat("any"),
|
||||
errTestNotFound, WithExpiry(time.Second), WithNotFoundExpiry(time.Second))
|
||||
var str string
|
||||
assert.Error(t, cn.Take(&str, "any", func(v interface{}) error {
|
||||
*v.(*string) = "value"
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
@@ -164,6 +209,35 @@ func TestCacheNode_TakeNotFound(t *testing.T) {
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.NoError(t, err)
|
||||
defer clean()
|
||||
|
||||
cn := cacheNode{
|
||||
rds: store,
|
||||
r: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
barrier: syncx.NewSingleFlight(),
|
||||
lock: new(sync.Mutex),
|
||||
unstableExpiry: mathx.NewUnstable(expiryDeviation),
|
||||
stat: NewStat("any"),
|
||||
errNotFound: errTestNotFound,
|
||||
}
|
||||
|
||||
var str string
|
||||
err = cn.Take(&str, "any", func(v interface{}) error {
|
||||
store.Set("any", "foo")
|
||||
return errTestNotFound
|
||||
})
|
||||
assert.True(t, cn.IsNotFound(err))
|
||||
|
||||
val, err := store.Get("any")
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, "foo", val)
|
||||
}
|
||||
assert.True(t, cn.IsNotFound(cn.Get("any", &str)))
|
||||
}
|
||||
|
||||
func TestCacheNode_TakeWithExpire(t *testing.T) {
|
||||
store, clean, err := redistest.CreateRedis()
|
||||
assert.Nil(t, err)
|
||||
|
||||
4
core/stores/cache/cacheopt.go
vendored
4
core/stores/cache/cacheopt.go
vendored
@@ -34,14 +34,14 @@ func newOptions(opts ...Option) Options {
|
||||
return o
|
||||
}
|
||||
|
||||
// WithExpiry returns a func to customize a Options with given expiry.
|
||||
// WithExpiry returns a func to customize an Options with given expiry.
|
||||
func WithExpiry(expiry time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.Expiry = expiry
|
||||
}
|
||||
}
|
||||
|
||||
// WithNotFoundExpiry returns a func to customize a Options with given not found expiry.
|
||||
// WithNotFoundExpiry returns a func to customize an Options with given not found expiry.
|
||||
func WithNotFoundExpiry(expiry time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.NotFoundExpiry = expiry
|
||||
|
||||
16
core/stores/cache/cachestat.go
vendored
16
core/stores/cache/cachestat.go
vendored
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
const statInterval = time.Minute
|
||||
@@ -25,7 +26,13 @@ func NewStat(name string) *Stat {
|
||||
ret := &Stat{
|
||||
name: name,
|
||||
}
|
||||
go ret.statLoop()
|
||||
|
||||
go func() {
|
||||
ticker := timex.NewTicker(statInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
ret.statLoop(ticker)
|
||||
}()
|
||||
|
||||
return ret
|
||||
}
|
||||
@@ -50,11 +57,8 @@ func (s *Stat) IncrementDbFails() {
|
||||
atomic.AddUint64(&s.DbFails, 1)
|
||||
}
|
||||
|
||||
func (s *Stat) statLoop() {
|
||||
ticker := time.NewTicker(statInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
func (s *Stat) statLoop(ticker timex.Ticker) {
|
||||
for range ticker.Chan() {
|
||||
total := atomic.SwapUint64(&s.Total, 0)
|
||||
if total == 0 {
|
||||
continue
|
||||
|
||||
28
core/stores/cache/cachestat_test.go
vendored
Normal file
28
core/stores/cache/cachestat_test.go
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
func TestCacheStat_statLoop(t *testing.T) {
|
||||
t.Run("stat loop total 0", func(t *testing.T) {
|
||||
var stat Stat
|
||||
ticker := timex.NewFakeTicker()
|
||||
go stat.statLoop(ticker)
|
||||
ticker.Tick()
|
||||
ticker.Tick()
|
||||
ticker.Stop()
|
||||
})
|
||||
|
||||
t.Run("stat loop total not 0", func(t *testing.T) {
|
||||
var stat Stat
|
||||
stat.IncrementTotal()
|
||||
ticker := timex.NewFakeTicker()
|
||||
go stat.statLoop(ticker)
|
||||
ticker.Tick()
|
||||
ticker.Tick()
|
||||
ticker.Stop()
|
||||
})
|
||||
}
|
||||
2
core/stores/cache/cleaner_test.go
vendored
2
core/stores/cache/cleaner_test.go
vendored
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/proc"
|
||||
)
|
||||
|
||||
func TestNextDelay(t *testing.T) {
|
||||
@@ -51,6 +52,7 @@ func TestNextDelay(t *testing.T) {
|
||||
next, ok := nextDelay(test.input)
|
||||
assert.Equal(t, test.ok, ok)
|
||||
assert.Equal(t, test.output, next)
|
||||
proc.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ func NewStore(c KvConf) Store {
|
||||
// because Store and redis.Redis has different methods.
|
||||
dispatcher := hash.NewConsistentHash()
|
||||
for _, node := range c {
|
||||
cn := node.NewRedis()
|
||||
cn := redis.MustNewRedis(node.RedisConf)
|
||||
dispatcher.AddWithWeight(cn, node.Weight)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,9 +26,14 @@ type (
|
||||
)
|
||||
|
||||
// NewBulkInserter returns a BulkInserter.
|
||||
func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkInserter {
|
||||
func NewBulkInserter(coll Collection, interval ...time.Duration) (*BulkInserter, error) {
|
||||
cloneColl, err := coll.Clone()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inserter := &dbInserter{
|
||||
collection: coll,
|
||||
collection: cloneColl,
|
||||
}
|
||||
|
||||
duration := flushInterval
|
||||
@@ -39,7 +44,7 @@ func NewBulkInserter(coll *mongo.Collection, interval ...time.Duration) *BulkIns
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(duration, inserter),
|
||||
inserter: inserter,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Flush flushes the inserter, writes all pending records.
|
||||
|
||||
@@ -15,7 +15,8 @@ func TestBulkInserter(t *testing.T) {
|
||||
|
||||
mt.Run("test", func(mt *mtest.T) {
|
||||
mt.AddMockResponses(mtest.CreateSuccessResponse(bson.D{{Key: "ok", Value: 1}}...))
|
||||
bulk := NewBulkInserter(mt.Coll)
|
||||
bulk, err := NewBulkInserter(createModel(mt).Collection)
|
||||
assert.Equal(t, err, nil)
|
||||
bulk.SetResultHandler(func(result *mongo.InsertManyResult, err error) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 2, len(result.InsertedIDs))
|
||||
|
||||
@@ -3,15 +3,12 @@ package mon
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second
|
||||
|
||||
var clientManager = syncx.NewResourceManager()
|
||||
|
||||
// ClosableClient wraps *mongo.Client and provides a Close method.
|
||||
@@ -30,9 +27,20 @@ func Inject(key string, client *mongo.Client) {
|
||||
clientManager.Inject(key, &ClosableClient{client})
|
||||
}
|
||||
|
||||
func getClient(url string) (*mongo.Client, error) {
|
||||
func getClient(url string, opts ...Option) (*mongo.Client, error) {
|
||||
val, err := clientManager.GetResource(url, func() (io.Closer, error) {
|
||||
cli, err := mongo.Connect(context.Background(), mopt.Client().ApplyURI(url))
|
||||
o := mopt.Client().ApplyURI(url)
|
||||
opts = append([]Option{defaultTimeoutOption()}, opts...)
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
cli, err := mongo.Connect(context.Background(), o)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = cli.Ping(context.Background(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func MustNewModel(uri, db, collection string, opts ...Option) *Model {
|
||||
|
||||
// NewModel returns a Model.
|
||||
func NewModel(uri, db, collection string, opts ...Option) (*Model, error) {
|
||||
cli, err := getClient(uri)
|
||||
cli, err := getClient(uri, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,14 +4,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
const defaultTimeout = time.Second * 3
|
||||
|
||||
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
|
||||
|
||||
type (
|
||||
options struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
options = mopt.ClientOptions
|
||||
|
||||
// Option defines the method to customize a mongo model.
|
||||
Option func(opts *options)
|
||||
@@ -22,8 +23,15 @@ func SetSlowThreshold(threshold time.Duration) {
|
||||
slowThreshold.Set(threshold)
|
||||
}
|
||||
|
||||
func defaultOptions() *options {
|
||||
return &options{
|
||||
timeout: defaultTimeout,
|
||||
func defaultTimeoutOption() Option {
|
||||
return func(opts *options) {
|
||||
opts.SetTimeout(defaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout set the mon client operation timeout.
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
mopt "go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
func TestSetSlowThreshold(t *testing.T) {
|
||||
@@ -13,6 +14,14 @@ func TestSetSlowThreshold(t *testing.T) {
|
||||
assert.Equal(t, time.Second, slowThreshold.Load())
|
||||
}
|
||||
|
||||
func TestDefaultOptions(t *testing.T) {
|
||||
assert.Equal(t, defaultTimeout, defaultOptions().timeout)
|
||||
func Test_defaultTimeoutOption(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
defaultTimeoutOption()(opts)
|
||||
assert.Equal(t, defaultTimeout, *opts.Timeout)
|
||||
}
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
opts := mopt.Client()
|
||||
WithTimeout(time.Second)(opts)
|
||||
assert.Equal(t, time.Second, *opts.Timeout)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/zeromicro/go-zero/core/trace"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
@@ -14,12 +13,10 @@ import (
|
||||
var mongoCmdAttributeKey = attribute.Key("mongo.cmd")
|
||||
|
||||
func startSpan(ctx context.Context, cmd string) (context.Context, oteltrace.Span) {
|
||||
tracer := otel.GetTracerProvider().Tracer(trace.TraceName)
|
||||
ctx, span := tracer.Start(ctx,
|
||||
spanName,
|
||||
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
|
||||
)
|
||||
tracer := trace.TracerFromContext(ctx)
|
||||
ctx, span := tracer.Start(ctx, spanName, oteltrace.WithSpanKind(oteltrace.SpanKindClient))
|
||||
span.SetAttributes(mongoCmdAttributeKey.String(cmd))
|
||||
|
||||
return ctx, span
|
||||
}
|
||||
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/executors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
const (
|
||||
flushInterval = time.Second
|
||||
maxBulkRows = 1000
|
||||
)
|
||||
|
||||
type (
|
||||
// ResultHandler is a handler that used to handle results.
|
||||
ResultHandler func(*mgo.BulkResult, error)
|
||||
|
||||
// A BulkInserter is used to insert bulk of mongo records.
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
}
|
||||
)
|
||||
|
||||
// NewBulkInserter returns a BulkInserter.
|
||||
func NewBulkInserter(session *mgo.Session, dbName string, collectionNamer func() string) *BulkInserter {
|
||||
inserter := &dbInserter{
|
||||
session: session,
|
||||
dbName: dbName,
|
||||
collectionNamer: collectionNamer,
|
||||
}
|
||||
|
||||
return &BulkInserter{
|
||||
executor: executors.NewPeriodicalExecutor(flushInterval, inserter),
|
||||
inserter: inserter,
|
||||
}
|
||||
}
|
||||
|
||||
// Flush flushes the inserter, writes all pending records.
|
||||
func (bi *BulkInserter) Flush() {
|
||||
bi.executor.Flush()
|
||||
}
|
||||
|
||||
// Insert inserts doc.
|
||||
func (bi *BulkInserter) Insert(doc interface{}) {
|
||||
bi.executor.Add(doc)
|
||||
}
|
||||
|
||||
// SetResultHandler sets the result handler.
|
||||
func (bi *BulkInserter) SetResultHandler(handler ResultHandler) {
|
||||
bi.executor.Sync(func() {
|
||||
bi.inserter.resultHandler = handler
|
||||
})
|
||||
}
|
||||
|
||||
type dbInserter struct {
|
||||
session *mgo.Session
|
||||
dbName string
|
||||
collectionNamer func() string
|
||||
documents []interface{}
|
||||
resultHandler ResultHandler
|
||||
}
|
||||
|
||||
func (in *dbInserter) AddTask(doc interface{}) bool {
|
||||
in.documents = append(in.documents, doc)
|
||||
return len(in.documents) >= maxBulkRows
|
||||
}
|
||||
|
||||
func (in *dbInserter) Execute(objs interface{}) {
|
||||
docs := objs.([]interface{})
|
||||
if len(docs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
bulk := in.session.DB(in.dbName).C(in.collectionNamer()).Bulk()
|
||||
bulk.Insert(docs...)
|
||||
bulk.Unordered()
|
||||
result, err := bulk.Run()
|
||||
if in.resultHandler != nil {
|
||||
in.resultHandler(result, err)
|
||||
} else if err != nil {
|
||||
logx.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (in *dbInserter) RemoveAll() interface{} {
|
||||
documents := in.documents
|
||||
in.documents = nil
|
||||
return documents
|
||||
}
|
||||
@@ -1,242 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo/internal"
|
||||
"github.com/zeromicro/go-zero/core/timex"
|
||||
)
|
||||
|
||||
const defaultSlowThreshold = time.Millisecond * 500
|
||||
|
||||
// ErrNotFound is an alias of mgo.ErrNotFound.
|
||||
var ErrNotFound = mgo.ErrNotFound
|
||||
|
||||
type (
|
||||
// Collection interface represents a mongo connection.
|
||||
Collection interface {
|
||||
Find(query interface{}) Query
|
||||
FindId(id interface{}) Query
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) Pipe
|
||||
Remove(selector interface{}) error
|
||||
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
|
||||
RemoveId(id interface{}) error
|
||||
Update(selector, update interface{}) error
|
||||
UpdateId(id, update interface{}) error
|
||||
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
|
||||
}
|
||||
|
||||
decoratedCollection struct {
|
||||
name string
|
||||
collection internal.MgoCollection
|
||||
brk breaker.Breaker
|
||||
}
|
||||
|
||||
keepablePromise struct {
|
||||
promise breaker.Promise
|
||||
log func(error)
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
|
||||
return &decoratedCollection{
|
||||
name: collection.FullName,
|
||||
collection: collection,
|
||||
brk: brk,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Find(query interface{}) Query {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedQuery{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedQuery{
|
||||
Query: c.collection.Find(query),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("find", duration, err, query)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) FindId(id interface{}) Query {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedQuery{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedQuery{
|
||||
Query: c.collection.FindId(id),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("findId", duration, err, id)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Insert(docs ...interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("insert", duration, err, docs...)
|
||||
}()
|
||||
|
||||
return c.collection.Insert(docs...)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Pipe(pipeline interface{}) Pipe {
|
||||
promise, err := c.brk.Allow()
|
||||
if err != nil {
|
||||
return rejectedPipe{}
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
return promisedPipe{
|
||||
Pipe: c.collection.Pipe(pipeline),
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(err error) {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("pipe", duration, err, pipeline)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Remove(selector interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("remove", duration, err, selector)
|
||||
}()
|
||||
|
||||
return c.collection.Remove(selector)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) RemoveAll(selector interface{}) (info *mgo.ChangeInfo, err error) {
|
||||
err = c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("removeAll", duration, err, selector)
|
||||
}()
|
||||
|
||||
info, err = c.collection.RemoveAll(selector)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) RemoveId(id interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("removeId", duration, err, id)
|
||||
}()
|
||||
|
||||
return c.collection.RemoveId(id)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Update(selector, update interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("update", duration, err, selector, update)
|
||||
}()
|
||||
|
||||
return c.collection.Update(selector, update)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) UpdateId(id, update interface{}) (err error) {
|
||||
return c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("updateId", duration, err, id, update)
|
||||
}()
|
||||
|
||||
return c.collection.UpdateId(id, update)
|
||||
}, acceptable)
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) Upsert(selector, update interface{}) (info *mgo.ChangeInfo, err error) {
|
||||
err = c.brk.DoWithAcceptable(func() error {
|
||||
startTime := timex.Now()
|
||||
defer func() {
|
||||
duration := timex.Since(startTime)
|
||||
c.logDuration("upsert", duration, err, selector, update)
|
||||
}()
|
||||
|
||||
info, err = c.collection.Upsert(selector, update)
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *decoratedCollection) logDuration(method string, duration time.Duration, err error, docs ...interface{}) {
|
||||
content, e := json.Marshal(docs)
|
||||
if e != nil {
|
||||
logx.Error(err)
|
||||
} else if err != nil {
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - fail(%s) - %s",
|
||||
c.name, method, err.Error(), string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("mongo(%s) - %s - fail(%s) - %s",
|
||||
c.name, method, err.Error(), string(content))
|
||||
}
|
||||
} else {
|
||||
if duration > slowThreshold.Load() {
|
||||
logx.WithDuration(duration).Slowf("[MONGO] mongo(%s) - slowcall - %s - ok - %s",
|
||||
c.name, method, string(content))
|
||||
} else {
|
||||
logx.WithDuration(duration).Infof("mongo(%s) - %s - ok - %s", c.name, method, string(content))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p keepablePromise) accept(err error) error {
|
||||
p.promise.Accept()
|
||||
p.log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p keepablePromise) keep(err error) error {
|
||||
if acceptable(err) {
|
||||
p.promise.Accept()
|
||||
} else {
|
||||
p.promise.Reject(err.Error())
|
||||
}
|
||||
|
||||
p.log(err)
|
||||
return err
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
return err == nil || err == mgo.ErrNotFound
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/mongo/internal"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
)
|
||||
|
||||
var errDummy = errors.New("dummy")
|
||||
|
||||
func TestKeepPromise_accept(t *testing.T) {
|
||||
p := new(mockPromise)
|
||||
kp := keepablePromise{
|
||||
promise: p,
|
||||
log: func(error) {},
|
||||
}
|
||||
assert.Nil(t, kp.accept(nil))
|
||||
assert.Equal(t, mgo.ErrNotFound, kp.accept(mgo.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestKeepPromise_keep(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
p := new(mockPromise)
|
||||
kp := keepablePromise{
|
||||
promise: p,
|
||||
log: func(error) {},
|
||||
}
|
||||
assert.Equal(t, test.err, kp.keep(test.err))
|
||||
assert.Equal(t, test.accepted, p.accepted)
|
||||
assert.Equal(t, test.reason, p.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCollection(t *testing.T) {
|
||||
col := newCollection(&mgo.Collection{
|
||||
Database: nil,
|
||||
Name: "foo",
|
||||
FullName: "bar",
|
||||
}, breaker.GetBreaker("localhost"))
|
||||
assert.Equal(t, "bar", col.(*decoratedCollection).name)
|
||||
}
|
||||
|
||||
func TestCollectionFind(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var query mgo.Query
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Find(gomock.Any()).Return(&query)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
actual := c.Find(nil)
|
||||
switch v := actual.(type) {
|
||||
case promisedQuery:
|
||||
assert.Equal(t, &query, v.Query)
|
||||
assert.Equal(t, errDummy, v.promise.keep(errDummy))
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
c.brk = new(dropBreaker)
|
||||
actual = c.Find(nil)
|
||||
assert.Equal(t, rejectedQuery{}, actual)
|
||||
}
|
||||
|
||||
func TestCollectionFindId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var query mgo.Query
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().FindId(gomock.Any()).Return(&query)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
actual := c.FindId(nil)
|
||||
switch v := actual.(type) {
|
||||
case promisedQuery:
|
||||
assert.Equal(t, &query, v.Query)
|
||||
assert.Equal(t, errDummy, v.promise.keep(errDummy))
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
c.brk = new(dropBreaker)
|
||||
actual = c.FindId(nil)
|
||||
assert.Equal(t, rejectedQuery{}, actual)
|
||||
}
|
||||
|
||||
func TestCollectionInsert(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Insert(nil, nil).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.Insert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.Insert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionPipe(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
var pipe mgo.Pipe
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Pipe(gomock.Any()).Return(&pipe)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
actual := c.Pipe(nil)
|
||||
switch v := actual.(type) {
|
||||
case promisedPipe:
|
||||
assert.Equal(t, &pipe, v.Pipe)
|
||||
assert.Equal(t, errDummy, v.promise.keep(errDummy))
|
||||
default:
|
||||
t.Fail()
|
||||
}
|
||||
c.brk = new(dropBreaker)
|
||||
actual = c.Pipe(nil)
|
||||
assert.Equal(t, rejectedPipe{}, actual)
|
||||
}
|
||||
|
||||
func TestCollectionRemove(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Remove(gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.Remove(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.Remove(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionRemoveAll(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().RemoveAll(gomock.Any()).Return(nil, errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
_, err := c.RemoveAll(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.RemoveAll(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionRemoveId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().RemoveId(gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.RemoveId(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.RemoveId(nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionUpdate(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Update(gomock.Any(), gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.Update(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.Update(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionUpdateId(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().UpdateId(gomock.Any(), gomock.Any()).Return(errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
err := c.UpdateId(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
err = c.UpdateId(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func TestCollectionUpsert(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
col.EXPECT().Upsert(gomock.Any(), gomock.Any()).Return(nil, errDummy)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
_, err := c.Upsert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
c.brk = new(dropBreaker)
|
||||
_, err = c.Upsert(nil, nil)
|
||||
assert.Equal(t, errDummy, err)
|
||||
}
|
||||
|
||||
func Test_logDuration(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
col := internal.NewMockMgoCollection(ctrl)
|
||||
c := decoratedCollection{
|
||||
collection: col,
|
||||
brk: breaker.NewBreaker(),
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
w := logx.NewWriter(&buf)
|
||||
o := logx.Reset()
|
||||
logx.SetWriter(w)
|
||||
|
||||
defer func() {
|
||||
logx.Reset()
|
||||
logx.SetWriter(o)
|
||||
}()
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", time.Millisecond, nil, "bar")
|
||||
assert.Contains(t, buf.String(), "foo")
|
||||
assert.Contains(t, buf.String(), "bar")
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", time.Millisecond, errors.New("bar"), make(chan int))
|
||||
assert.Contains(t, buf.String(), "bar")
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", slowThreshold.Load()+time.Millisecond, errors.New("bar"))
|
||||
assert.Contains(t, buf.String(), "bar")
|
||||
assert.Contains(t, buf.String(), "slowcall")
|
||||
|
||||
buf.Reset()
|
||||
c.logDuration("foo", slowThreshold.Load()+time.Millisecond, nil)
|
||||
assert.Contains(t, buf.String(), "foo")
|
||||
assert.Contains(t, buf.String(), "slowcall")
|
||||
}
|
||||
|
||||
type mockPromise struct {
|
||||
accepted bool
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *mockPromise) Accept() {
|
||||
p.accepted = true
|
||||
}
|
||||
|
||||
func (p *mockPromise) Reject(reason string) {
|
||||
p.reason = reason
|
||||
}
|
||||
|
||||
type dropBreaker struct{}
|
||||
|
||||
func (d *dropBreaker) Name() string {
|
||||
return "dummy"
|
||||
}
|
||||
|
||||
func (d *dropBreaker) Allow() (breaker.Promise, error) {
|
||||
return nil, errDummy
|
||||
}
|
||||
|
||||
func (d *dropBreaker) Do(req func() error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dropBreaker) DoWithAcceptable(req func() error, acceptable breaker.Acceptable) error {
|
||||
return errDummy
|
||||
}
|
||||
|
||||
func (d *dropBreaker) DoWithFallback(req func() error, fallback func(err error) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dropBreaker) DoWithFallbackAcceptable(req func() error, fallback func(err error) error,
|
||||
acceptable breaker.Acceptable) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:generate mockgen -package internal -destination collection_mock.go -source collection.go
|
||||
|
||||
package internal
|
||||
|
||||
import "github.com/globalsign/mgo"
|
||||
|
||||
// MgoCollection interface represents a mgo collection.
|
||||
type MgoCollection interface {
|
||||
Find(query interface{}) *mgo.Query
|
||||
FindId(id interface{}) *mgo.Query
|
||||
Insert(docs ...interface{}) error
|
||||
Pipe(pipeline interface{}) *mgo.Pipe
|
||||
Remove(selector interface{}) error
|
||||
RemoveAll(selector interface{}) (*mgo.ChangeInfo, error)
|
||||
RemoveId(id interface{}) error
|
||||
Update(selector, update interface{}) error
|
||||
UpdateId(id, update interface{}) error
|
||||
Upsert(selector, update interface{}) (*mgo.ChangeInfo, error)
|
||||
}
|
||||
@@ -1,181 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: collection.go
|
||||
|
||||
// Package internal is a generated GoMock package.
|
||||
package internal
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
mgo "github.com/globalsign/mgo"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockMgoCollection is a mock of MgoCollection interface
|
||||
type MockMgoCollection struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMgoCollectionMockRecorder
|
||||
}
|
||||
|
||||
// MockMgoCollectionMockRecorder is the mock recorder for MockMgoCollection
|
||||
type MockMgoCollectionMockRecorder struct {
|
||||
mock *MockMgoCollection
|
||||
}
|
||||
|
||||
// NewMockMgoCollection creates a new mock instance
|
||||
func NewMockMgoCollection(ctrl *gomock.Controller) *MockMgoCollection {
|
||||
mock := &MockMgoCollection{ctrl: ctrl}
|
||||
mock.recorder = &MockMgoCollectionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockMgoCollection) EXPECT() *MockMgoCollectionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Find mocks base method
|
||||
func (m *MockMgoCollection) Find(query interface{}) *mgo.Query {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Find", query)
|
||||
ret0, _ := ret[0].(*mgo.Query)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Find indicates an expected call of Find
|
||||
func (mr *MockMgoCollectionMockRecorder) Find(query interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockMgoCollection)(nil).Find), query)
|
||||
}
|
||||
|
||||
// FindId mocks base method
|
||||
func (m *MockMgoCollection) FindId(id interface{}) *mgo.Query {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FindId", id)
|
||||
ret0, _ := ret[0].(*mgo.Query)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FindId indicates an expected call of FindId
|
||||
func (mr *MockMgoCollectionMockRecorder) FindId(id interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindId", reflect.TypeOf((*MockMgoCollection)(nil).FindId), id)
|
||||
}
|
||||
|
||||
// Insert mocks base method
|
||||
func (m *MockMgoCollection) Insert(docs ...interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{}
|
||||
for _, a := range docs {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Insert", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Insert indicates an expected call of Insert
|
||||
func (mr *MockMgoCollectionMockRecorder) Insert(docs ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockMgoCollection)(nil).Insert), docs...)
|
||||
}
|
||||
|
||||
// Pipe mocks base method
|
||||
func (m *MockMgoCollection) Pipe(pipeline interface{}) *mgo.Pipe {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Pipe", pipeline)
|
||||
ret0, _ := ret[0].(*mgo.Pipe)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Pipe indicates an expected call of Pipe
|
||||
func (mr *MockMgoCollectionMockRecorder) Pipe(pipeline interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pipe", reflect.TypeOf((*MockMgoCollection)(nil).Pipe), pipeline)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockMgoCollection) Remove(selector interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Remove", selector)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockMgoCollectionMockRecorder) Remove(selector interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockMgoCollection)(nil).Remove), selector)
|
||||
}
|
||||
|
||||
// RemoveAll mocks base method
|
||||
func (m *MockMgoCollection) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveAll", selector)
|
||||
ret0, _ := ret[0].(*mgo.ChangeInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RemoveAll indicates an expected call of RemoveAll
|
||||
func (mr *MockMgoCollectionMockRecorder) RemoveAll(selector interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveAll", reflect.TypeOf((*MockMgoCollection)(nil).RemoveAll), selector)
|
||||
}
|
||||
|
||||
// RemoveId mocks base method
|
||||
func (m *MockMgoCollection) RemoveId(id interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveId", id)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveId indicates an expected call of RemoveId
|
||||
func (mr *MockMgoCollectionMockRecorder) RemoveId(id interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveId", reflect.TypeOf((*MockMgoCollection)(nil).RemoveId), id)
|
||||
}
|
||||
|
||||
// Update mocks base method
|
||||
func (m *MockMgoCollection) Update(selector, update interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Update", selector, update)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Update indicates an expected call of Update
|
||||
func (mr *MockMgoCollectionMockRecorder) Update(selector, update interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockMgoCollection)(nil).Update), selector, update)
|
||||
}
|
||||
|
||||
// UpdateId mocks base method
|
||||
func (m *MockMgoCollection) UpdateId(id, update interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateId", id, update)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateId indicates an expected call of UpdateId
|
||||
func (mr *MockMgoCollectionMockRecorder) UpdateId(id, update interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateId", reflect.TypeOf((*MockMgoCollection)(nil).UpdateId), id, update)
|
||||
}
|
||||
|
||||
// Upsert mocks base method
|
||||
func (m *MockMgoCollection) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Upsert", selector, update)
|
||||
ret0, _ := ret[0].(*mgo.ChangeInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Upsert indicates an expected call of Upsert
|
||||
func (mr *MockMgoCollectionMockRecorder) Upsert(selector, update interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockMgoCollection)(nil).Upsert), selector, update)
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
//go:generate mockgen -package mongo -destination iter_mock.go -source iter.go Iter
|
||||
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"github.com/globalsign/mgo/bson"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
// Iter interface represents a mongo iter.
|
||||
Iter interface {
|
||||
All(result interface{}) error
|
||||
Close() error
|
||||
Done() bool
|
||||
Err() error
|
||||
For(result interface{}, f func() error) error
|
||||
Next(result interface{}) bool
|
||||
State() (int64, []bson.Raw)
|
||||
Timeout() bool
|
||||
}
|
||||
|
||||
// A ClosableIter is a closable mongo iter.
|
||||
ClosableIter struct {
|
||||
Iter
|
||||
Cleanup func()
|
||||
}
|
||||
|
||||
promisedIter struct {
|
||||
Iter
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedIter struct{}
|
||||
)
|
||||
|
||||
func (i promisedIter) All(result interface{}) error {
|
||||
return i.promise.keep(i.Iter.All(result))
|
||||
}
|
||||
|
||||
func (i promisedIter) Close() error {
|
||||
return i.promise.keep(i.Iter.Close())
|
||||
}
|
||||
|
||||
func (i promisedIter) Err() error {
|
||||
return i.Iter.Err()
|
||||
}
|
||||
|
||||
func (i promisedIter) For(result interface{}, f func() error) error {
|
||||
var ferr error
|
||||
err := i.Iter.For(result, func() error {
|
||||
ferr = f()
|
||||
return ferr
|
||||
})
|
||||
if ferr == err {
|
||||
return i.promise.accept(err)
|
||||
}
|
||||
|
||||
return i.promise.keep(err)
|
||||
}
|
||||
|
||||
// Close closes a mongo iter.
|
||||
func (it *ClosableIter) Close() error {
|
||||
err := it.Iter.Close()
|
||||
it.Cleanup()
|
||||
return err
|
||||
}
|
||||
|
||||
func (i rejectedIter) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Close() error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Done() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i rejectedIter) Err() error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) For(result interface{}, f func() error) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (i rejectedIter) Next(result interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (i rejectedIter) State() (int64, []bson.Raw) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (i rejectedIter) Timeout() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: iter.go
|
||||
|
||||
// Package mongo is a generated GoMock package.
|
||||
package mongo
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
bson "github.com/globalsign/mgo/bson"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockIter is a mock of Iter interface
|
||||
type MockIter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockIterMockRecorder
|
||||
}
|
||||
|
||||
// MockIterMockRecorder is the mock recorder for MockIter
|
||||
type MockIterMockRecorder struct {
|
||||
mock *MockIter
|
||||
}
|
||||
|
||||
// NewMockIter creates a new mock instance
|
||||
func NewMockIter(ctrl *gomock.Controller) *MockIter {
|
||||
mock := &MockIter{ctrl: ctrl}
|
||||
mock.recorder = &MockIterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockIter) EXPECT() *MockIterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// All mocks base method
|
||||
func (m *MockIter) All(result interface{}) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "All", result)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// All indicates an expected call of All
|
||||
func (mr *MockIterMockRecorder) All(result interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "All", reflect.TypeOf((*MockIter)(nil).All), result)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockIter) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockIterMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockIter)(nil).Close))
|
||||
}
|
||||
|
||||
// Done mocks base method
|
||||
func (m *MockIter) Done() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Done")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Done indicates an expected call of Done
|
||||
func (mr *MockIterMockRecorder) Done() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockIter)(nil).Done))
|
||||
}
|
||||
|
||||
// Err mocks base method
|
||||
func (m *MockIter) Err() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Err")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Err indicates an expected call of Err
|
||||
func (mr *MockIterMockRecorder) Err() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockIter)(nil).Err))
|
||||
}
|
||||
|
||||
// For mocks base method
|
||||
func (m *MockIter) For(result interface{}, f func() error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "For", result, f)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// For indicates an expected call of For
|
||||
func (mr *MockIterMockRecorder) For(result, f interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "For", reflect.TypeOf((*MockIter)(nil).For), result, f)
|
||||
}
|
||||
|
||||
// Next mocks base method
|
||||
func (m *MockIter) Next(result interface{}) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Next", result)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Next indicates an expected call of Next
|
||||
func (mr *MockIterMockRecorder) Next(result interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockIter)(nil).Next), result)
|
||||
}
|
||||
|
||||
// State mocks base method
|
||||
func (m *MockIter) State() (int64, []bson.Raw) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "State")
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].([]bson.Raw)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// State indicates an expected call of State
|
||||
func (mr *MockIterMockRecorder) State() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "State", reflect.TypeOf((*MockIter)(nil).State))
|
||||
}
|
||||
|
||||
// Timeout mocks base method
|
||||
func (m *MockIter) Timeout() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Timeout")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Timeout indicates an expected call of Timeout
|
||||
func (mr *MockIterMockRecorder) Timeout() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Timeout", reflect.TypeOf((*MockIter)(nil).Timeout))
|
||||
}
|
||||
@@ -1,264 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
"github.com/zeromicro/go-zero/core/stringx"
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
func TestClosableIter_Close(t *testing.T) {
|
||||
errs := []error{
|
||||
nil,
|
||||
mgo.ErrNotFound,
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
cleaned := syncx.NewAtomicBool()
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Close().Return(err)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: func() {
|
||||
cleaned.Set(true)
|
||||
},
|
||||
}
|
||||
assert.Equal(t, err, ci.Close())
|
||||
assert.True(t, cleaned.True())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_AllAndClose(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().All(gomock.Any()).Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.All(nil))
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Close().Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.Close())
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_Err(t *testing.T) {
|
||||
errs := []error{
|
||||
nil,
|
||||
mgo.ErrNotFound,
|
||||
}
|
||||
|
||||
for _, err := range errs {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Err().Return(err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, err, pi.Err())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromisedIter_For(t *testing.T) {
|
||||
tests := []struct {
|
||||
err error
|
||||
accepted bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
err: nil,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: mgo.ErrNotFound,
|
||||
accepted: true,
|
||||
reason: "",
|
||||
},
|
||||
{
|
||||
err: errors.New("any"),
|
||||
accepted: false,
|
||||
reason: "any",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(stringx.RandId(), func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().For(gomock.Any(), gomock.Any()).Return(test.err)
|
||||
promise := new(mockPromise)
|
||||
pi := promisedIter{
|
||||
Iter: iter,
|
||||
promise: keepablePromise{
|
||||
promise: promise,
|
||||
log: func(error) {},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, test.err, pi.For(nil, nil))
|
||||
assert.Equal(t, test.accepted, promise.accepted)
|
||||
assert.Equal(t, test.reason, promise.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectedIter_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).All(nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_Close(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Close())
|
||||
}
|
||||
|
||||
func TestRejectedIter_Done(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Done())
|
||||
}
|
||||
|
||||
func TestRejectedIter_Err(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).Err())
|
||||
}
|
||||
|
||||
func TestRejectedIter_For(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedIter).For(nil, nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_Next(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Next(nil))
|
||||
}
|
||||
|
||||
func TestRejectedIter_State(t *testing.T) {
|
||||
n, raw := new(rejectedIter).State()
|
||||
assert.Equal(t, int64(0), n)
|
||||
assert.Nil(t, raw)
|
||||
}
|
||||
|
||||
func TestRejectedIter_Timeout(t *testing.T) {
|
||||
assert.False(t, new(rejectedIter).Timeout())
|
||||
}
|
||||
|
||||
func TestIter_Done(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Done().Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Done())
|
||||
}
|
||||
|
||||
func TestIter_Next(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Next(gomock.Any()).Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Next(nil))
|
||||
}
|
||||
|
||||
func TestIter_State(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().State().Return(int64(1), nil)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
n, raw := ci.State()
|
||||
assert.Equal(t, int64(1), n)
|
||||
assert.Nil(t, raw)
|
||||
}
|
||||
|
||||
func TestIter_Timeout(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iter := NewMockIter(ctrl)
|
||||
iter.EXPECT().Timeout().Return(true)
|
||||
ci := ClosableIter{
|
||||
Iter: iter,
|
||||
Cleanup: nil,
|
||||
}
|
||||
assert.True(t, ci.Timeout())
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
// A Model is a mongo model.
|
||||
type Model struct {
|
||||
session *concurrentSession
|
||||
db *mgo.Database
|
||||
collection string
|
||||
brk breaker.Breaker
|
||||
opts []Option
|
||||
}
|
||||
|
||||
// MustNewModel returns a Model, exits on errors.
|
||||
func MustNewModel(url, collection string, opts ...Option) *Model {
|
||||
model, err := NewModel(url, collection, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
// NewModel returns a Model.
|
||||
func NewModel(url, collection string, opts ...Option) (*Model, error) {
|
||||
session, err := getConcurrentSession(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Model{
|
||||
session: session,
|
||||
// If name is empty, the database name provided in the dialed URL is used instead
|
||||
db: session.DB(""),
|
||||
collection: collection,
|
||||
brk: breaker.GetBreaker(url),
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Find finds a record with given query.
|
||||
func (mm *Model) Find(query interface{}) (Query, error) {
|
||||
return mm.query(func(c Collection) Query {
|
||||
return c.Find(query)
|
||||
})
|
||||
}
|
||||
|
||||
// FindId finds a record with given id.
|
||||
func (mm *Model) FindId(id interface{}) (Query, error) {
|
||||
return mm.query(func(c Collection) Query {
|
||||
return c.FindId(id)
|
||||
})
|
||||
}
|
||||
|
||||
// GetCollection returns a Collection with given session.
|
||||
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
||||
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
|
||||
}
|
||||
|
||||
// Insert inserts docs into mm.
|
||||
func (mm *Model) Insert(docs ...interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Insert(docs...)
|
||||
})
|
||||
}
|
||||
|
||||
// Pipe returns a Pipe with given pipeline.
|
||||
func (mm *Model) Pipe(pipeline interface{}) (Pipe, error) {
|
||||
return mm.pipe(func(c Collection) Pipe {
|
||||
return c.Pipe(pipeline)
|
||||
})
|
||||
}
|
||||
|
||||
// PutSession returns the given session.
|
||||
func (mm *Model) PutSession(session *mgo.Session) {
|
||||
mm.session.putSession(session)
|
||||
}
|
||||
|
||||
// Remove removes the records with given selector.
|
||||
func (mm *Model) Remove(selector interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Remove(selector)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveAll removes all with given selector and returns a mgo.ChangeInfo.
|
||||
func (mm *Model) RemoveAll(selector interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
|
||||
return c.RemoveAll(selector)
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveId removes a record with given id.
|
||||
func (mm *Model) RemoveId(id interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.RemoveId(id)
|
||||
})
|
||||
}
|
||||
|
||||
// TakeSession gets a session.
|
||||
func (mm *Model) TakeSession() (*mgo.Session, error) {
|
||||
return mm.session.takeSession(mm.opts...)
|
||||
}
|
||||
|
||||
// Update updates a record with given selector.
|
||||
func (mm *Model) Update(selector, update interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.Update(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateId updates a record with given id.
|
||||
func (mm *Model) UpdateId(id, update interface{}) error {
|
||||
return mm.execute(func(c Collection) error {
|
||||
return c.UpdateId(id, update)
|
||||
})
|
||||
}
|
||||
|
||||
// Upsert upserts a record with given selector, and returns a mgo.ChangeInfo.
|
||||
func (mm *Model) Upsert(selector, update interface{}) (*mgo.ChangeInfo, error) {
|
||||
return mm.change(func(c Collection) (*mgo.ChangeInfo, error) {
|
||||
return c.Upsert(selector, update)
|
||||
})
|
||||
}
|
||||
|
||||
func (mm *Model) change(fn func(c Collection) (*mgo.ChangeInfo, error)) (*mgo.ChangeInfo, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) execute(fn func(c Collection) error) error {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session))
|
||||
}
|
||||
|
||||
func (mm *Model) pipe(fn func(c Collection) Pipe) (Pipe, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
func (mm *Model) query(fn func(c Collection) Query) (Query, error) {
|
||||
session, err := mm.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mm.PutSession(session)
|
||||
|
||||
return fn(mm.GetCollection(session)), nil
|
||||
}
|
||||
|
||||
// WithTimeout customizes an operation with given timeout.
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(opts *options) {
|
||||
opts.timeout = timeout
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
o := defaultOptions()
|
||||
WithTimeout(time.Second)(o)
|
||||
assert.Equal(t, time.Second, o.timeout)
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/syncx"
|
||||
)
|
||||
|
||||
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
|
||||
|
||||
type (
|
||||
options struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// Option defines the method to customize a mongo model.
|
||||
Option func(opts *options)
|
||||
)
|
||||
|
||||
// SetSlowThreshold sets the slow threshold.
|
||||
func SetSlowThreshold(threshold time.Duration) {
|
||||
slowThreshold.Set(threshold)
|
||||
}
|
||||
|
||||
func defaultOptions() *options {
|
||||
return &options{
|
||||
timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSetSlowThreshold(t *testing.T) {
|
||||
assert.Equal(t, defaultSlowThreshold, slowThreshold.Load())
|
||||
SetSlowThreshold(time.Second)
|
||||
assert.Equal(t, time.Second, slowThreshold.Load())
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
// Pipe interface represents a mongo pipe.
|
||||
Pipe interface {
|
||||
All(result interface{}) error
|
||||
AllowDiskUse() Pipe
|
||||
Batch(n int) Pipe
|
||||
Collation(collation *mgo.Collation) Pipe
|
||||
Explain(result interface{}) error
|
||||
Iter() Iter
|
||||
One(result interface{}) error
|
||||
SetMaxTime(d time.Duration) Pipe
|
||||
}
|
||||
|
||||
promisedPipe struct {
|
||||
*mgo.Pipe
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedPipe struct{}
|
||||
)
|
||||
|
||||
func (p promisedPipe) All(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.All(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) AllowDiskUse() Pipe {
|
||||
p.Pipe.AllowDiskUse()
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Batch(n int) Pipe {
|
||||
p.Pipe.Batch(n)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Collation(collation *mgo.Collation) Pipe {
|
||||
p.Pipe.Collation(collation)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p promisedPipe) Explain(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.Explain(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) Iter() Iter {
|
||||
return promisedIter{
|
||||
Iter: p.Pipe.Iter(),
|
||||
promise: p.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (p promisedPipe) One(result interface{}) error {
|
||||
return p.promise.keep(p.Pipe.One(result))
|
||||
}
|
||||
|
||||
func (p promisedPipe) SetMaxTime(d time.Duration) Pipe {
|
||||
p.Pipe.SetMaxTime(d)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) AllowDiskUse() Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Batch(n int) Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Collation(collation *mgo.Collation) Pipe {
|
||||
return p
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Explain(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) Iter() Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
|
||||
func (p rejectedPipe) One(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (p rejectedPipe) SetMaxTime(d time.Duration) Pipe {
|
||||
return p
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
func TestRejectedPipe_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).All(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_AllowDiskUse(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.AllowDiskUse())
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Batch(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.Batch(1))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Collation(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.Collation(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Explain(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).Explain(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_Iter(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedPipe).Iter())
|
||||
}
|
||||
|
||||
func TestRejectedPipe_One(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedPipe).One(nil))
|
||||
}
|
||||
|
||||
func TestRejectedPipe_SetMaxTime(t *testing.T) {
|
||||
var p rejectedPipe
|
||||
assert.Equal(t, p, p.SetMaxTime(0))
|
||||
}
|
||||
@@ -1,285 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
// Query interface represents a mongo query.
|
||||
Query interface {
|
||||
All(result interface{}) error
|
||||
Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error)
|
||||
Batch(n int) Query
|
||||
Collation(collation *mgo.Collation) Query
|
||||
Comment(comment string) Query
|
||||
Count() (int, error)
|
||||
Distinct(key string, result interface{}) error
|
||||
Explain(result interface{}) error
|
||||
For(result interface{}, f func() error) error
|
||||
Hint(indexKey ...string) Query
|
||||
Iter() Iter
|
||||
Limit(n int) Query
|
||||
LogReplay() Query
|
||||
MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error)
|
||||
One(result interface{}) error
|
||||
Prefetch(p float64) Query
|
||||
Select(selector interface{}) Query
|
||||
SetMaxScan(n int) Query
|
||||
SetMaxTime(d time.Duration) Query
|
||||
Skip(n int) Query
|
||||
Snapshot() Query
|
||||
Sort(fields ...string) Query
|
||||
Tail(timeout time.Duration) Iter
|
||||
}
|
||||
|
||||
promisedQuery struct {
|
||||
*mgo.Query
|
||||
promise keepablePromise
|
||||
}
|
||||
|
||||
rejectedQuery struct{}
|
||||
)
|
||||
|
||||
func (q promisedQuery) All(result interface{}) error {
|
||||
return q.promise.keep(q.Query.All(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
info, err := q.Query.Apply(change, result)
|
||||
return info, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Batch(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Batch(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Collation(collation *mgo.Collation) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Collation(collation),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Comment(comment string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Comment(comment),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Count() (int, error) {
|
||||
v, err := q.Query.Count()
|
||||
return v, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Distinct(key string, result interface{}) error {
|
||||
return q.promise.keep(q.Query.Distinct(key, result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Explain(result interface{}) error {
|
||||
return q.promise.keep(q.Query.Explain(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) For(result interface{}, f func() error) error {
|
||||
var ferr error
|
||||
err := q.Query.For(result, func() error {
|
||||
ferr = f()
|
||||
return ferr
|
||||
})
|
||||
if ferr == err {
|
||||
return q.promise.accept(err)
|
||||
}
|
||||
|
||||
return q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) Hint(indexKey ...string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Hint(indexKey...),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Iter() Iter {
|
||||
return promisedIter{
|
||||
Iter: q.Query.Iter(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Limit(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Limit(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) LogReplay() Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.LogReplay(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
info, err := q.Query.MapReduce(job, result)
|
||||
return info, q.promise.keep(err)
|
||||
}
|
||||
|
||||
func (q promisedQuery) One(result interface{}) error {
|
||||
return q.promise.keep(q.Query.One(result))
|
||||
}
|
||||
|
||||
func (q promisedQuery) Prefetch(p float64) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Prefetch(p),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Select(selector interface{}) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Select(selector),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) SetMaxScan(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.SetMaxScan(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) SetMaxTime(d time.Duration) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.SetMaxTime(d),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Skip(n int) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Skip(n),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Snapshot() Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Snapshot(),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Sort(fields ...string) Query {
|
||||
return promisedQuery{
|
||||
Query: q.Query.Sort(fields...),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q promisedQuery) Tail(timeout time.Duration) Iter {
|
||||
return promisedIter{
|
||||
Iter: q.Query.Tail(timeout),
|
||||
promise: q.promise,
|
||||
}
|
||||
}
|
||||
|
||||
func (q rejectedQuery) All(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Apply(change mgo.Change, result interface{}) (*mgo.ChangeInfo, error) {
|
||||
return nil, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Batch(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Collation(collation *mgo.Collation) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Comment(comment string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Count() (int, error) {
|
||||
return 0, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Distinct(key string, result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Explain(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) For(result interface{}, f func() error) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Hint(indexKey ...string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Iter() Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Limit(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) LogReplay() Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) MapReduce(job *mgo.MapReduce, result interface{}) (*mgo.MapReduceInfo, error) {
|
||||
return nil, breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) One(result interface{}) error {
|
||||
return breaker.ErrServiceUnavailable
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Prefetch(p float64) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Select(selector interface{}) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) SetMaxScan(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) SetMaxTime(d time.Duration) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Skip(n int) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Snapshot() Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Sort(fields ...string) Query {
|
||||
return q
|
||||
}
|
||||
|
||||
func (q rejectedQuery) Tail(timeout time.Duration) Iter {
|
||||
return rejectedIter{}
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
func Test_rejectedQuery_All(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).All(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Apply(t *testing.T) {
|
||||
info, err := new(rejectedQuery).Apply(mgo.Change{}, nil)
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Batch(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Batch(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Collation(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Collation(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Comment(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Comment(""))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Count(t *testing.T) {
|
||||
n, err := new(rejectedQuery).Count()
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Distinct(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Distinct("", nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Explain(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).Explain(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_For(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).For(nil, nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Hint(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Hint())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Iter(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Iter())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Limit(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Limit(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_LogReplay(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.LogReplay())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_MapReduce(t *testing.T) {
|
||||
info, err := new(rejectedQuery).MapReduce(nil, nil)
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, err)
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_One(t *testing.T) {
|
||||
assert.Equal(t, breaker.ErrServiceUnavailable, new(rejectedQuery).One(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Prefetch(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Prefetch(1))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Select(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Select(nil))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_SetMaxScan(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.SetMaxScan(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_SetMaxTime(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.SetMaxTime(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Skip(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Skip(0))
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Snapshot(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Snapshot())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Sort(t *testing.T) {
|
||||
var q rejectedQuery
|
||||
assert.Equal(t, q, q.Sort())
|
||||
}
|
||||
|
||||
func Test_rejectedQuery_Tail(t *testing.T) {
|
||||
assert.EqualValues(t, rejectedIter{}, new(rejectedQuery).Tail(0))
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user